diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt index 69b02c700e32ce..602b77d3b853a5 100644 --- a/.ci/docker/aotriton_version.txt +++ b/.ci/docker/aotriton_version.txt @@ -1,5 +1,5 @@ -0.6b +0.7b manylinux_2_17 rocm6.2 -7f07e8a1cb1f99627eb6d77f5c0e9295c775f3c7 -e4ab195d2bd19e939c675a13280c29714c6ef9f2cf420690da150fa0cac043b1 +9be04068c3c0857a4cfd17d7e39e71d0423ebac2 +3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291 diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index decc0caeac3514..dedf591db38d01 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -236,7 +236,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-py3-clang10-onnx) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=10 PROTOBUF=yes DB=yes @@ -245,7 +245,7 @@ case "$image" in ONNX=yes ;; pytorch-linux-focal-py3-clang9-android-ndk-r21e) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=9 LLVMDEV=yes PROTOBUF=yes @@ -254,8 +254,8 @@ case "$image" in GRADLE_VERSION=6.8.3 NINJA_VERSION=1.9.0 ;; - pytorch-linux-focal-py3.8-clang10) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-focal-py3.9-clang10) + ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=10 PROTOBUF=yes DB=yes @@ -276,8 +276,8 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-focal-py3.8-gcc9) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-focal-py3.9-gcc9) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -286,7 +286,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-rocm-n-1-py3) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -297,7 +297,7 @@ case "$image" in TRITON=yes ;; pytorch-linux-focal-rocm-n-py3) - ANACONDA_PYTHON_VERSION=3.8 + ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=9 PROTOBUF=yes DB=yes @@ -318,8 +318,8 @@ case "$image" in CONDA_CMAKE=yes TRITON=yes ;; - pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 PROTOBUF=yes DB=yes @@ -330,8 +330,8 @@ case "$image" in DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12) + ANACONDA_PYTHON_VERSION=3.9 CUDA_VERSION=11.8 CUDNN_VERSION=9 CLANG_VERSION=12 @@ -355,8 +355,8 @@ case "$image" in CONDA_CMAKE=yes VISION=yes ;; - pytorch-linux-jammy-py3.8-gcc11) - ANACONDA_PYTHON_VERSION=3.8 + pytorch-linux-jammy-py3.9-gcc11) + ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 PROTOBUF=yes DB=yes diff --git a/.ci/docker/centos-rocm/Dockerfile b/.ci/docker/centos-rocm/Dockerfile index bfac9ddd859084..30ce1406e3f81c 100644 --- a/.ci/docker/centos-rocm/Dockerfile +++ b/.ci/docker/centos-rocm/Dockerfile @@ -108,10 +108,10 @@ ENV CMAKE_C_COMPILER cc ENV CMAKE_CXX_COMPILER c++ COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton (Early fail) COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index f6867de3907803..639bfcdd0420e2 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -5e9bab8c5956249e75a0f187bf8075df97ca2555 +cd1c833b079adb324871dcbbe75b43d42ffc0ade diff --git a/.ci/docker/ci_commit_pins/halide.txt b/.ci/docker/ci_commit_pins/halide.txt index 17ff17c31524e6..aaf88275edad42 100644 --- a/.ci/docker/ci_commit_pins/halide.txt +++ b/.ci/docker/ci_commit_pins/halide.txt @@ -1 +1 @@ -340136fec6d3ebc73e7a19eba1663e9b0ba8ab2d \ No newline at end of file +461c12871f336fe6f57b55d6a297f13ef209161b \ No newline at end of file diff --git a/.ci/docker/ci_commit_pins/triton-rocm.txt b/.ci/docker/ci_commit_pins/triton-rocm.txt deleted file mode 100644 index 0cb336acccb5b1..00000000000000 --- a/.ci/docker/ci_commit_pins/triton-rocm.txt +++ /dev/null @@ -1 +0,0 @@ -21eae954efa5bf584da70324b640288c3ee7aede diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index 3e47e877deb5fd..e25f89a42f2e6f 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -1b2f15840e0d70eec50d84c7a0575cb835524def +91b14bf5593cf58a8541f3e6b9125600a867d4ef diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 41c8d1602b6bef..7148e9c99cec22 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -dedb7bdf339a3546896d4820366ca562c586bfa0 +5fe38ffd73c2ac6ed6323b554205186696631c6f diff --git a/.ci/docker/common/install_aotriton.sh b/.ci/docker/common/install_aotriton.sh index 8b340ee219de2a..2aee95c48d4794 100755 --- a/.ci/docker/common/install_aotriton.sh +++ b/.ci/docker/common/install_aotriton.sh @@ -4,12 +4,12 @@ set -ex source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" -TARBALL='aotriton.tar.bz2' +TARBALL='aotriton.tar.gz' # This read command alwasy returns with exit code 1 read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true ARCH=$(uname -m) AOTRITON_INSTALL_PREFIX="$1" -AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.bz2" +AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz" cd "${AOTRITON_INSTALL_PREFIX}" # Must use -L to follow redirects diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 4d8a7b92e34ec6..f3c198044ffed0 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -5,32 +5,22 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then BASE_URL="https://repo.anaconda.com/miniconda" + CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" + if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" + fi MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) -if [[ $(uname -m) == "aarch64" ]]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" - case "$MAJOR_PYTHON_VERSION" in - 3) - CONDA_FILE="Miniforge3-Linux-aarch64.sh" - ;; - *) - echo "Unsupported ANACONDA_PYTHON_VERSION: $ANACONDA_PYTHON_VERSION" - exit 1 - ;; - esac -else case "$MAJOR_PYTHON_VERSION" in - 3) - CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - ;; + 3);; *) echo "Unsupported ANACONDA_PYTHON_VERSION: $ANACONDA_PYTHON_VERSION" exit 1 ;; esac -fi mkdir -p /opt/conda chown jenkins:jenkins /opt/conda diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index 1bd25fb2de91f0..caf0467c523f97 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -7,7 +7,7 @@ PYTHON_DOWNLOAD_GITHUB_BRANCH=https://github.com/python/cpython/archive/refs/hea GET_PIP_URL=https://bootstrap.pypa.io/get-pip.py # Python versions to be installed in /opt/$VERSION_NO -CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.8.1 3.9.0 3.10.1 3.11.0 3.12.0 3.13.0"} +CPYTHON_VERSIONS=${CPYTHON_VERSIONS:-"3.8.1 3.9.0 3.10.1 3.11.0 3.12.0 3.13.0 3.13.0t"} function check_var { if [ -z "$1" ]; then @@ -22,6 +22,13 @@ function do_cpython_build { check_var $py_ver check_var $py_folder tar -xzf Python-$py_ver.tgz + + local additional_flags="" + if [ "$py_ver" == "3.13.0t" ]; then + additional_flags=" --disable-gil" + mv cpython-3.13/ cpython-3.13t/ + fi + pushd $py_folder local prefix="/opt/_internal/cpython-${py_ver}" @@ -37,8 +44,10 @@ function do_cpython_build { local openssl_flags="--with-openssl=${WITH_OPENSSL} --with-openssl-rpath=auto" fi + + # -Wformat added for https://bugs.python.org/issue17547 on Python 2.6 - CFLAGS="-Wformat" ./configure --prefix=${prefix} ${openssl_flags} ${shared_flags} > /dev/null + CFLAGS="-Wformat" ./configure --prefix=${prefix} ${openssl_flags} ${shared_flags} ${additional_flags} > /dev/null make -j40 > /dev/null make install > /dev/null @@ -58,7 +67,8 @@ function do_cpython_build { if [ -e ${prefix}/bin/pip3 ] && [ ! -e ${prefix}/bin/pip ]; then ln -s pip3 ${prefix}/bin/pip fi - ${prefix}/bin/pip install wheel==0.34.2 + # install setuptools since python 3.12 is required to use distutils + ${prefix}/bin/pip install wheel==0.34.2 setuptools==68.2.2 local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))") ln -s ${prefix} /opt/python/${abi_tag} } @@ -68,7 +78,14 @@ function build_cpython { check_var $py_ver check_var $PYTHON_DOWNLOAD_URL local py_ver_folder=$py_ver - if [ "$py_ver" = "3.13.0" ]; then + + if [ "$py_ver" = "3.13.0t" ]; then + PY_VER_SHORT="3.13" + PYT_VER_SHORT="3.13t" + check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH + wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz + do_cpython_build $py_ver cpython-$PYT_VER_SHORT + elif [ "$py_ver" = "3.13.0" ]; then PY_VER_SHORT="3.13" check_var $PYTHON_DOWNLOAD_GITHUB_BRANCH wget $PYTHON_DOWNLOAD_GITHUB_BRANCH/$PY_VER_SHORT.tar.gz -O Python-$py_ver.tgz diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index b77684bdf9a05c..c4b3f3e02a7838 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -5,7 +5,7 @@ set -ex # cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html mkdir tmp_cusparselt && cd tmp_cusparselt -if [[ ${CUDA_VERSION:0:4} =~ ^12\.[2-4]$ ]]; then +if [[ ${CUDA_VERSION:0:4} =~ ^12\.[2-6]$ ]]; then arch_path='sbsa' export TARGETARCH=${TARGETARCH:-$(uname -m)} if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 1d384233163d8d..b99ecac283d646 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -15,7 +15,7 @@ pip_install \ flatbuffers==2.0 \ mock==5.0.1 \ ninja==1.10.2 \ - networkx==2.0 \ + networkx==2.5 \ numpy==1.24.2 # ONNXRuntime should be installed before installing @@ -30,10 +30,9 @@ pip_install \ pip_install coloredlogs packaging -pip_install onnxruntime==1.18 -pip_install onnx==1.16.0 -# pip_install "onnxscript@git+https://github.com/microsoft/onnxscript@3e869ef8ccf19b5ebd21c10d3e9c267c9a9fa729" --no-deps -pip_install onnxscript==0.1.0.dev20240613 --no-deps +pip_install onnxruntime==1.18.1 +pip_install onnx==1.16.2 +pip_install onnxscript==0.1.0.dev20240831 --no-deps # required by onnxscript pip_install ml_dtypes diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index d4a0dc80ee177e..6e5fb8839c686e 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -12,10 +12,7 @@ conda_reinstall() { as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $* } -if [ -n "${ROCM_VERSION}" ]; then - TRITON_REPO="https://github.com/openai/triton" - TRITON_TEXT_FILE="triton-rocm" -elif [ -n "${XPU_VERSION}" ]; then +if [ -n "${XPU_VERSION}" ]; then TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton" TRITON_TEXT_FILE="triton-xpu" else diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 6e0445b1a74a28..bd48d696c9230b 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -30,9 +30,14 @@ dill==0.3.7 #Pinned versions: 0.3.7 #test that import: dynamo/test_replay_record.py test_dataloader.py test_datapipe.py test_serialization.py -expecttest==0.1.6 +expecttest==0.2.1 #Description: method for writing tests where test framework auto populates # the expected output based on previous runs +#Pinned versions: 0.2.1 +#test that import: + +fbscribelogger==0.1.6 +#Description: write to scribe from authenticated jobs on CI #Pinned versions: 0.1.6 #test that import: @@ -85,7 +90,7 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==1.10.0 +mypy==1.11.2 # Pin MyPy version because new errors are likely to appear with each release #Description: linter #Pinned versions: 1.10.0 @@ -104,7 +109,7 @@ networkx==2.8.8 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py numba==0.49.0 ; python_version < "3.9" -numba==0.54.1 ; python_version == "3.9" +numba==0.55.2 ; python_version == "3.9" numba==0.55.2 ; python_version == "3.10" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 @@ -332,3 +337,8 @@ onnxscript==0.1.0.dev20240817 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: + +parameterized==0.8.1 +#Description: Parameterizes unittests, both the tests themselves and the entire testing class +#Pinned versions: +#test that import: diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 4a36342fcab700..fd2a01863fdd30 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.0.0 +3.1.0 diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index ee9ede8ba611b6..07e25f533a71f9 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -100,10 +100,10 @@ ARG TRITON # try to reach out to S3, which docker build runners don't have access COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh -COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt +COPY ci_commit_pins/triton.txt triton.txt COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install AOTriton COPY ./aotriton_version.txt aotriton_version.txt diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index 02cd1133a050c3..41f690c4ab386c 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -30,6 +30,7 @@ RUN bash ./install_docs_reqs.sh && rm install_docs_reqs.sh ARG ANACONDA_PYTHON_VERSION ARG CONDA_CMAKE ARG DOCS +ARG BUILD_ENVIRONMENT ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION ENV PATH /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/bin:/opt/conda/bin:$PATH ENV DOCS=$DOCS diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 9c2b4096d169de..a9662bcac2cefd 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -285,9 +285,8 @@ else if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* ]]; then if [[ "$BUILD_ENVIRONMENT" != *py3.8* ]]; then - # Install numpy-2.0 release candidate for builds - # Which should be backward compatible with Numpy-1.X - python -mpip install --pre numpy==2.0.0rc1 + # Install numpy-2.0.2 for builds which are backward compatible with 1.X + python -mpip install --pre numpy==2.0.2 fi WERROR=1 python setup.py clean diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index a54b8c360eba56..47707afdc2b14a 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -9,15 +9,13 @@ if [[ -n "$CONDA_ENV" ]]; then export PATH="$CONDA_ENV/bin":$PATH fi -# Test that OpenMP is enabled for non-arm64 build -if [[ ${BUILD_ENVIRONMENT} != *arm64* ]]; then - pushd test - if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") == "1" ]]; then - echo "Build should have OpenMP enabled, but torch.backends.openmp.is_available() is False" - exit 1 - fi - popd +# Test that OpenMP is enabled +pushd test +if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available()))") == "1" ]]; then + echo "Build should have OpenMP enabled, but torch.backends.openmp.is_available() is False" + exit 1 fi +popd setup_test_python() { # The CircleCI worker hostname doesn't resolve to an address. @@ -27,8 +25,9 @@ setup_test_python() { echo "Ninja version: $(ninja --version)" echo "Python version: $(which python) ($(python --version))" - # Increase default limit on open file handles from 256 to 1024 - ulimit -n 1024 + # Set the limit on open file handles to 16384 + # might help with intermittent compiler test failures + ulimit -n 16384 } test_python_all() { diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index daf71c306849ed..b866ee8162e0ea 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -401,9 +401,9 @@ pr_time_benchmarks() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" - PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" + PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" echo "benchmark results on current PR: " - cat "$TEST_REPORTS_DIR/pr_time_benchmarks_after.txt" + cat "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" } @@ -575,10 +575,10 @@ test_single_dynamo_benchmark() { fi if [[ "${TEST_CONFIG}" == *_avx2* ]]; then - TEST_CONFIG=${TEST_CONFIG::-5} + TEST_CONFIG=${TEST_CONFIG//_avx2/} fi if [[ "${TEST_CONFIG}" == *_avx512* ]]; then - TEST_CONFIG=${TEST_CONFIG::-7} + TEST_CONFIG=${TEST_CONFIG//_avx512/} fi python "benchmarks/dynamo/$suite.py" \ --ci --accuracy --timing --explain \ @@ -596,6 +596,9 @@ test_single_dynamo_benchmark() { test_inductor_micro_benchmark() { TEST_REPORTS_DIR=$(pwd)/test/test-reports + if [[ "${TEST_CONFIG}" == *cpu* ]]; then + test_inductor_set_cpu_affinity + fi python benchmarks/gpt_fast/benchmark.py --output "${TEST_REPORTS_DIR}/gpt_fast_benchmark.csv" } @@ -1380,14 +1383,16 @@ test_executorch() { assert_git_not_dirty } -test_linux_aarch64(){ +test_linux_aarch64() { python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ - test_transformers test_multiprocessing test_numpy_interop --verbose + test_transformers test_multiprocessing test_numpy_interop \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Dynamo tests python test/run_test.py --include dynamo/test_compile dynamo/test_backends dynamo/test_comptime dynamo/test_config \ dynamo/test_functions dynamo/test_fx_passes_pre_grad dynamo/test_interop dynamo/test_model_output dynamo/test_modules \ - dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles --verbose + dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose # Inductor tests python test/run_test.py --include inductor/test_torchinductor inductor/test_benchmark_fusion inductor/test_codecache \ @@ -1397,7 +1402,8 @@ test_linux_aarch64(){ inductor/test_max_autotune inductor/test_memory_planning inductor/test_metrics inductor/test_multi_kernel inductor/test_pad_mm \ inductor/test_pattern_matcher inductor/test_perf inductor/test_profiler inductor/test_select_algorithm inductor/test_smoke \ inductor/test_split_cat_fx_passes inductor/test_standalone_compile inductor/test_torchinductor \ - inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes --verbose + inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes \ + --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose } if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then @@ -1479,9 +1485,7 @@ elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" if [[ "${SHARD_NUMBER}" == 1 ]]; then - if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.8-gcc11-build ]]; then - # Temporarily skip test_inductor_aoti due to https://github.com/pytorch/pytorch/issues/130311 - test_inductor_aoti + if [[ "${BUILD_ENVIRONMENT}" != linux-jammy-py3.9-gcc11-build ]]; then test_inductor_distributed fi fi diff --git a/.ci/pytorch/win-test-helpers/build_pytorch.bat b/.ci/pytorch/win-test-helpers/build_pytorch.bat index 824acc09c5854a..92078d22326396 100644 --- a/.ci/pytorch/win-test-helpers/build_pytorch.bat +++ b/.ci/pytorch/win-test-helpers/build_pytorch.bat @@ -24,6 +24,12 @@ call %INSTALLER_DIR%\install_sccache.bat if errorlevel 1 goto fail if not errorlevel 0 goto fail +if "%USE_XPU%"=="1" ( + :: Install xpu support packages + call %INSTALLER_DIR%\install_xpu.bat + if errorlevel 1 exit /b 1 +) + :: Miniconda has been installed as part of the Windows AMI with all the dependencies. :: We just need to activate it here call %INSTALLER_DIR%\activate_miniconda3.bat @@ -43,6 +49,16 @@ if "%VC_VERSION%" == "" ( ) if errorlevel 1 goto fail if not errorlevel 0 goto fail + +if "%USE_XPU%"=="1" ( + :: Activate xpu environment - VS env is required for xpu + call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" + if errorlevel 1 exit /b 1 + :: Reduce build time. Only have MTL self-hosted runner now + SET TORCH_XPU_ARCH_LIST=xe-lpg + SET USE_KINETO=0 +) + @echo on popd diff --git a/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat b/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat new file mode 100644 index 00000000000000..b9fd597929dded --- /dev/null +++ b/.ci/pytorch/win-test-helpers/installation-helpers/install_xpu.bat @@ -0,0 +1,91 @@ +@echo on +REM Description: Install Intel Support Packages on Windows +REM BKM reference: https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html + +set XPU_INSTALL_MODE=%~1 +if "%XPU_INSTALL_MODE%"=="" goto xpu_bundle_install_start +if "%XPU_INSTALL_MODE%"=="bundle" goto xpu_bundle_install_start +if "%XPU_INSTALL_MODE%"=="driver" goto xpu_driver_install_start +if "%XPU_INSTALL_MODE%"=="all" goto xpu_driver_install_start + +:arg_error + +echo Illegal XPU installation mode. The value can be "bundle"/"driver"/"all" +echo If keep the value as space, will use default "bundle" mode +exit /b 1 + +:xpu_driver_install_start +:: TODO Need more testing for driver installation +set XPU_DRIVER_LINK=https://downloadmirror.intel.com/830975/gfx_win_101.5972.exe +curl -o xpu_driver.exe --retry 3 --retry-all-errors -k %XPU_DRIVER_LINK% +echo "XPU Driver installing..." +start /wait "Intel XPU Driver Installer" "xpu_driver.exe" +if errorlevel 1 exit /b 1 +del xpu_driver.exe +if "%XPU_INSTALL_MODE%"=="driver" goto xpu_install_end + +:xpu_bundle_install_start + +set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI +set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-for-pytorch-gpu-dev_p_0.5.3.37_offline.exe +set XPU_PTI_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-pti-dev_p_0.9.0.37_offline.exe +set XPU_BUNDLE_VERSION=0.5.3+31 +set XPU_PTI_VERSION=0.9.0+36 +set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.intel-for-pytorch-gpu-dev.product +set XPU_PTI_PRODUCT_NAME=intel.oneapi.win.intel-pti-dev.product +set XPU_BUNDLE_INSTALLED=0 +set XPU_PTI_INSTALLED=0 +set XPU_BUNDLE_UNINSTALL=0 +set XPU_PTI_UNINSTALL=0 + +:: Check if XPU bundle is target version or already installed +if exist "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" goto xpu_bundle_ver_check +goto xpu_bundle_install + +:xpu_bundle_ver_check + +"%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --list-products > xpu_bundle_installed_ver.log + +for /f "tokens=1,2" %%a in (xpu_bundle_installed_ver.log) do ( + if "%%a"=="%XPU_BUNDLE_PRODUCT_NAME%" ( + echo %%a Installed Version: %%b + set XPU_BUNDLE_INSTALLED=1 + if not "%XPU_BUNDLE_VERSION%"=="%%b" ( + start /wait "Installer Title" "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --action=remove --eula=accept --silent --product-id %XPU_BUNDLE_PRODUCT_NAME% --product-ver %%b --log-dir uninstall_bundle + set XPU_BUNDLE_UNINSTALL=1 + ) + ) + if "%%a"=="%XPU_PTI_PRODUCT_NAME%" ( + echo %%a Installed Version: %%b + set XPU_PTI_INSTALLED=1 + if not "%XPU_PTI_VERSION%"=="%%b" ( + start /wait "Installer Title" "%XPU_BUNDLE_PARENT_DIR%\Installer\installer.exe" --action=remove --eula=accept --silent --product-id %XPU_PTI_PRODUCT_NAME% --product-ver %%b --log-dir uninstall_bundle + set XPU_PTI_UNINSTALL=1 + ) + ) +) +if errorlevel 1 exit /b 1 +if exist xpu_bundle_installed_ver.log del xpu_bundle_installed_ver.log +if "%XPU_BUNDLE_INSTALLED%"=="0" goto xpu_bundle_install +if "%XPU_BUNDLE_UNINSTALL%"=="1" goto xpu_bundle_install +if "%XPU_PTI_INSTALLED%"=="0" goto xpu_pti_install +if "%XPU_PTI_UNINSTALL%"=="1" goto xpu_pti_install +goto xpu_install_end + +:xpu_bundle_install + +curl -o xpu_bundle.exe --retry 3 --retry-all-errors -k %XPU_BUNDLE_URL% +echo "XPU Bundle installing..." +start /wait "Intel Pytorch Bundle Installer" "xpu_bundle.exe" --action=install --eula=accept --silent --log-dir install_bundle +if errorlevel 1 exit /b 1 +del xpu_bundle.exe + +:xpu_pti_install + +curl -o xpu_pti.exe --retry 3 --retry-all-errors -k %XPU_PTI_URL% +echo "XPU PTI installing..." +start /wait "Intel PTI Installer" "xpu_pti.exe" --action=install --eula=accept --silent --log-dir install_bundle +if errorlevel 1 exit /b 1 +del xpu_pti.exe + +:xpu_install_end diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 39fcb65132f676..09b624183c7ae7 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -40,6 +40,12 @@ python -m pip install pytest-rerunfailures==10.3 pytest-cpp==2.3.0 tensorboard== # Install Z3 optional dependency for Windows builds. python -m pip install z3-solver==4.12.2.0 +# Install tlparse for test\dynamo\test_structured_trace.py UTs. +python -m pip install tlparse==0.3.25 + +# Install parameterized +python -m pip install parameterized==0.8.1 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 5d92c9099bff99..81d7bf2c511a05 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -116,15 +116,14 @@ if [[ "$PACKAGE_TYPE" == libtorch ]]; then cd /tmp/libtorch fi -if [[ "$GPU_ARCH_TYPE" == xpu ]]; then - # Refer https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html - source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh - source /opt/intel/oneapi/pti/latest/env/vars.sh -fi - # Test the package /builder/check_binary.sh +if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_TYPE" != *rocm* && "$PACKAGE_TYPE" != libtorch ]]; then + # Exclude s390, xpu, rocm and libtorch builds from smoke testing + python /builder/test/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled +fi + # Clean temp files cd /builder && git clean -ffdx diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index e918635922a45e..106d0917ca68c5 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -90,7 +90,7 @@ fi if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*rocm.* && $(uname) == "Linux" ]]; then TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}; ${TRITON_CONSTRAINT}" if [[ -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_BUILD_VERSION" =~ .*dev.* ]]; then - TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton-rocm.txt) + TRITON_SHORTHASH=$(cut -c1-10 $PYTORCH_ROOT/.ci/docker/ci_commit_pins/triton.txt) TRITON_REQUIREMENT="pytorch-triton-rocm==${TRITON_VERSION}+${TRITON_SHORTHASH}; ${TRITON_CONSTRAINT}" fi if [[ -z "${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:-}" ]]; then diff --git a/.circleci/scripts/binary_windows_build.sh b/.circleci/scripts/binary_windows_build.sh index 0f8eeb8cebe05d..d62ce7d77efe44 100644 --- a/.circleci/scripts/binary_windows_build.sh +++ b/.circleci/scripts/binary_windows_build.sh @@ -10,6 +10,11 @@ export SCCACHE_BUCKET=ossci-compiler-cache export SCCACHE_IGNORE_SERVER_IO_ERROR=1 export VC_YEAR=2019 +if [[ "$DESIRED_CUDA" == 'xpu' ]]; then + export VC_YEAR=2022 + export USE_SCCACHE=0 +fi + echo "Free space on filesystem before build:" df -h diff --git a/.circleci/scripts/binary_windows_test.sh b/.circleci/scripts/binary_windows_test.sh index bbf0efbb5e52ff..b9f801fb0c50f4 100644 --- a/.circleci/scripts/binary_windows_test.sh +++ b/.circleci/scripts/binary_windows_test.sh @@ -6,6 +6,10 @@ source "${BINARY_ENV_FILE:-/c/w/env}" export CUDA_VERSION="${DESIRED_CUDA/cu/}" export VC_YEAR=2019 +if [[ "$DESIRED_CUDA" == 'xpu' ]]; then + export VC_YEAR=2022 +fi + pushd "$BUILDER_ROOT" ./windows/internal/smoke_test.bat diff --git a/.flake8 b/.flake8 index c789819f477bae..4e1cb4642d418f 100644 --- a/.flake8 +++ b/.flake8 @@ -57,7 +57,7 @@ per-file-ignores = torch/distributed/_tensor/_collective_utils.py: TOR901 # This is a full package that happen to live within the test # folder, so ok to skip - test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901 + test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py: TOR901 optional-ascii-coding = True exclude = ./.git, diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 3408ca3b1c3b88..bc83f0c32ee785 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -3,8 +3,6 @@ self-hosted-runner: # GitHub hosted x86 Linux runners - linux.20_04.4x - linux.20_04.16x - # Repo-specific LF hosted ARC runners - - linux.large.arc # Organization-wide AWS Linux Runners - linux.large - linux.2xlarge @@ -16,7 +14,9 @@ self-hosted-runner: - linux.24xlarge - linux.24xlarge.ephemeral - linux.arm64.2xlarge + - linux.arm64.2xlarge.ephemeral - linux.arm64.m7g.4xlarge + - linux.arm64.m7g.4xlarge.ephemeral - linux.4xlarge.nvidia.gpu - linux.8xlarge.nvidia.gpu - linux.16xlarge.nvidia.gpu @@ -40,6 +40,7 @@ self-hosted-runner: - amz2023.linux.24xlarge - amz2023.linux.arm64.2xlarge - amz2023.linux.arm64.m7g.4xlarge + - amz2023.linux.arm64.m7g.4xlarge.ephemeral - amz2023.linux.4xlarge.nvidia.gpu - amz2023.linux.8xlarge.nvidia.gpu - amz2023.linux.16xlarge.nvidia.gpu @@ -60,6 +61,7 @@ self-hosted-runner: # Organization wide AWS Windows runners - windows.g4dn.xlarge - windows.g4dn.xlarge.nonephemeral + - windows.4xlarge - windows.4xlarge.nonephemeral - windows.8xlarge.nvidia.gpu - windows.8xlarge.nvidia.gpu.nonephemeral diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 3973bc933bf981..c835e7c2838715 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -b3f6f511f2a1082bd56b13a3f6794e7fc3ba4862 +ba696ea3dfec4cbe693bf06a84c75dc196077f5b diff --git a/.github/label_to_label.yml b/.github/label_to_label.yml index e6c66a5e56cf68..5d6544a2f50f0a 100644 --- a/.github/label_to_label.yml +++ b/.github/label_to_label.yml @@ -1,13 +1,50 @@ # Use this to auto apply labels based on other labels. Applies to both PRs and # issues. Currently only supports any and all - any: - - "module: custom operators" + - "module: opcheck" + then: + - "module: custom-operators" +- any: + - "module: custom-operators" + - "module: functionalization" - "module: aotdispatch" + - "module: higher order operators" + - "module: fakeTensor" + - "module: ProxyTensor" + - "module: library" + - "module: reinplacing" then: - "module: pt2-dispatcher" +- any: + - "module: vmap" + then: + - "module: functorch" +- any: + - "module: reinplacing" + then: + - "module: inductor" +- any: + - "module: pt2 optimizer" + then: + - "module: dynamo" +- any: + - "module: flex attention" + then: + - "module: higher order operators" +- any: + - "module: aotinductor" + then: + - "oncall: export" - any: - "module: dynamo" - "module: pt2-dispatcher" - "module: inductor" + - "module: aotinductor" + - "module: cudagraphs" + - "oncall: export" + - "module: startup-tracing-compile" + - "module: compiled autograd" + - "module: flex attention" + - "module: dynamic shapes" then: - "oncall: pt2" diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml index aaa0e21c92ef12..482b55e04423e9 100644 --- a/.github/lf-canary-scale-config.yml +++ b/.github/lf-canary-scale-config.yml @@ -7,10 +7,14 @@ # runners. Runners listed here will be available as self hosted # runners, configuration is directly pulled from the main branch. # -# NOTE (Apr, 5, 2021): Linux runners are currently all an amazonlinux2 # -# NOTE (Jan 5, 2021): Linux runners are all non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues +# NOTES: +# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls +# to avoid RequestLimitExceeded issues +# - When updating this file, run the following command to validate the YAML and to generate +# corresponding versions of scale-config for the pytorch/pytorch repo and merge the +# pytorch/pytorch changes before merging these changes. +# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` # # TODO: Add some documentation on how the auto-scaling works # @@ -35,8 +39,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.10xlarge.avx2: disk_size: 200 instance_type: m4.10xlarge @@ -44,11 +46,6 @@ runner_types: max_available: 450 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xl.spr-metal: disk_size: 200 instance_type: c7i.metal-24xl @@ -56,11 +53,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.16xlarge.spr: disk_size: 200 instance_type: c7i.16xlarge @@ -68,11 +60,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.9xlarge.ephemeral: disk_size: 200 instance_type: c5.9xlarge @@ -81,8 +68,6 @@ runner_types: os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.12xlarge.ephemeral: @@ -92,11 +77,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.16xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.16xlarge @@ -104,11 +84,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xlarge: disk_size: 150 instance_type: c5.24xlarge @@ -116,11 +91,6 @@ runner_types: max_available: 500 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.24xlarge.ephemeral: disk_size: 150 instance_type: c5.24xlarge @@ -128,11 +98,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -140,11 +105,6 @@ runner_types: max_available: 3120 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.4xlarge: disk_size: 150 instance_type: c5.4xlarge @@ -155,8 +115,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.4xlarge @@ -164,11 +122,6 @@ runner_types: max_available: 1000 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.8xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.8xlarge @@ -179,8 +132,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g4dn.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g4dn.12xlarge @@ -188,11 +139,6 @@ runner_types: max_available: 250 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g4dn.metal.nvidia.gpu: disk_size: 150 instance_type: g4dn.metal @@ -200,11 +146,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.48xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.48xlarge @@ -212,11 +153,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.12xlarge @@ -224,11 +160,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g5.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.4xlarge @@ -236,11 +167,6 @@ runner_types: max_available: 2400 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.g6.4xlarge.experimental.nvidia.gpu: disk_size: 150 instance_type: g6.4xlarge @@ -251,8 +177,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.large: max_available: 1200 disk_size: 15 @@ -260,11 +184,6 @@ runner_types: is_ephemeral: false os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.c.linux.arm64.2xlarge: disk_size: 256 instance_type: t4g.2xlarge @@ -272,11 +191,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.linux.arm64.m7g.4xlarge: disk_size: 256 instance_type: m7g.4xlarge @@ -284,11 +198,20 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.c.linux.arm64.2xlarge.ephemeral: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + lf.c.linux.arm64.m7g.4xlarge.ephemeral: + disk_size: 256 + instance_type: m7g.4xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 lf.c.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal @@ -296,11 +219,6 @@ runner_types: max_available: 100 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.c.windows.g4dn.xlarge: disk_size: 256 instance_type: g4dn.xlarge diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml index 338dc4c534bb22..7c352157cb464a 100644 --- a/.github/lf-scale-config.yml +++ b/.github/lf-scale-config.yml @@ -7,10 +7,14 @@ # runners. Runners listed here will be available as self hosted # runners, configuration is directly pulled from the main branch. # -# NOTE (Apr, 5, 2021): Linux runners are currently all an amazonlinux2 # -# NOTE (Jan 5, 2021): Linux runners are all non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues +# NOTES: +# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls +# to avoid RequestLimitExceeded issues +# - When updating this file, run the following command to validate the YAML and to generate +# corresponding versions of scale-config for the pytorch/pytorch repo and merge the +# pytorch/pytorch changes before merging these changes. +# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` # # TODO: Add some documentation on how the auto-scaling works # @@ -35,8 +39,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.10xlarge.avx2: disk_size: 200 instance_type: m4.10xlarge @@ -44,11 +46,6 @@ runner_types: max_available: 450 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xl.spr-metal: disk_size: 200 instance_type: c7i.metal-24xl @@ -56,11 +53,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.16xlarge.spr: disk_size: 200 instance_type: c7i.16xlarge @@ -68,11 +60,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.9xlarge.ephemeral: disk_size: 200 instance_type: c5.9xlarge @@ -81,8 +68,6 @@ runner_types: os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 am2: ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.12xlarge.ephemeral: @@ -92,11 +77,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.16xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.16xlarge @@ -104,11 +84,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xlarge: disk_size: 150 instance_type: c5.24xlarge @@ -116,11 +91,6 @@ runner_types: max_available: 500 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.24xlarge.ephemeral: disk_size: 150 instance_type: c5.24xlarge @@ -128,11 +98,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.2xlarge: disk_size: 150 instance_type: c5.2xlarge @@ -140,11 +105,6 @@ runner_types: max_available: 3120 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.4xlarge: disk_size: 150 instance_type: c5.4xlarge @@ -155,8 +115,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.4xlarge @@ -164,11 +122,6 @@ runner_types: max_available: 1000 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.8xlarge.nvidia.gpu: disk_size: 150 instance_type: g3.8xlarge @@ -179,8 +132,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g4dn.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g4dn.12xlarge @@ -188,11 +139,6 @@ runner_types: max_available: 250 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g4dn.metal.nvidia.gpu: disk_size: 150 instance_type: g4dn.metal @@ -200,11 +146,6 @@ runner_types: max_available: 300 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.48xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.48xlarge @@ -212,11 +153,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.12xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.12xlarge @@ -224,11 +160,6 @@ runner_types: max_available: 150 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g5.4xlarge.nvidia.gpu: disk_size: 150 instance_type: g5.4xlarge @@ -236,11 +167,6 @@ runner_types: max_available: 2400 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.g6.4xlarge.experimental.nvidia.gpu: disk_size: 150 instance_type: g6.4xlarge @@ -251,8 +177,6 @@ runner_types: variants: amz2023: ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.large: max_available: 1200 disk_size: 15 @@ -260,11 +184,6 @@ runner_types: is_ephemeral: false os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-x86_64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs lf.linux.arm64.2xlarge: disk_size: 256 instance_type: t4g.2xlarge @@ -272,11 +191,6 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.linux.arm64.m7g.4xlarge: disk_size: 256 instance_type: m7g.4xlarge @@ -284,11 +198,20 @@ runner_types: max_available: 200 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 + lf.linux.arm64.2xlarge.ephemeral: + disk_size: 256 + instance_type: t4g.2xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 + lf.linux.arm64.m7g.4xlarge.ephemeral: + disk_size: 256 + instance_type: m7g.4xlarge + is_ephemeral: true + max_available: 200 + os: linux + ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 lf.linux.arm64.m7g.metal: disk_size: 256 instance_type: m7g.metal @@ -296,11 +219,6 @@ runner_types: max_available: 100 os: linux ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - variants: - amz2023: - ami: al2023-ami-2023.5.20240701.0-kernel-6.1-arm64 - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-arm64-gp2 lf.windows.g4dn.xlarge: disk_size: 256 instance_type: g4dn.xlarge diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 7d0976f2bdddc1..350010708be630 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -86,6 +86,18 @@ - pull - inductor +- name: OSS CI / pytorchbot / slow tests + patterns: + - test/slow_tests.json + approved_by: + - pytorchbot + ignore_flaky_failures: false + mandatory_checks_name: + - EasyCLA + - Lint + - pull + - slow + - name: OSS CI /pytorchbot / Executorch patterns: - .ci/docker/ci_commit_pins/executorch.txt @@ -107,8 +119,8 @@ mandatory_checks_name: - EasyCLA - Lint - - pull / linux-focal-py3_8-clang9-xla / build - - pull / linux-focal-py3_8-clang9-xla / test (xla, 1, 1, linux.12xlarge) + - pull / linux-focal-py3_9-clang9-xla / build + - pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge) - name: Documentation patterns: diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 991c139b89001f..e43c39396c5f0e 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -9,6 +9,7 @@ ciflow_push_tags: - ciflow/inductor-rocm - ciflow/inductor-perf-compare - ciflow/inductor-micro-benchmark +- ciflow/inductor-micro-benchmark-cpu-x86 - ciflow/inductor-cu124 - ciflow/linux-aarch64 - ciflow/mps diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index c3221044889c74..c72d1b568ca112 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -1,6 +1,7 @@ boto3==1.19.12 hypothesis==6.56.4 -expecttest==0.1.6 +expecttest==0.2.1 +fbscribelogger==0.1.6 librosa>=0.6.2 mpmath==1.3.0 networkx==2.8.7 @@ -30,3 +31,4 @@ optree==0.12.1 # NB: test_hparams_* from test_tensorboard is failing with protobuf 5.26.0 in # which the stringify metadata is wrong when escaping double quote protobuf==3.20.2 +parameterized==0.8.1 diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 7ee2cb4b8e61b5..096b20fe0909c0 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -15,9 +15,7 @@ def read_triton_pin(device: str = "cuda") -> str: triton_file = "triton.txt" - if device == "rocm": - triton_file = "triton-rocm.txt" - elif device == "xpu": + if device == "xpu": triton_file = "triton-xpu.txt" with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / triton_file) as f: return f.read().strip() diff --git a/.github/scripts/check_labels.py b/.github/scripts/check_labels.py index a29ef24a414c54..10be42c3fd5647 100755 --- a/.github/scripts/check_labels.py +++ b/.github/scripts/check_labels.py @@ -27,6 +27,12 @@ def parse_args() -> Any: parser = ArgumentParser("Check PR labels") parser.add_argument("pr_num", type=int) + # add a flag to return a non-zero exit code if the PR does not have the required labels + parser.add_argument( + "--exit-non-zero", + action="store_true", + help="Return a non-zero exit code if the PR does not have the required labels", + ) return parser.parse_args() @@ -41,10 +47,13 @@ def main() -> None: if not has_required_labels(pr): print(LABEL_ERR_MSG) add_label_err_comment(pr) + if args.exit_non_zero: + sys.exit(1) else: delete_all_label_err_comments(pr) except Exception as e: - pass + if args.exit_non_zero: + sys.exit(1) sys.exit(0) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 9dc8ebed3c09cd..993459bf320471 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -325,6 +325,7 @@ def generate_wheels_matrix( os: str, arches: Optional[List[str]] = None, python_versions: Optional[List[str]] = None, + use_split_build: bool = False, ) -> List[Dict[str, str]]: package_type = "wheel" if os == "linux" or os == "linux-aarch64" or os == "linux-s390x": @@ -340,7 +341,7 @@ def generate_wheels_matrix( if os == "linux": arches += CPU_CXX11_ABI_ARCH + CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES elif os == "windows": - arches += CUDA_ARCHES + arches += CUDA_ARCHES + XPU_ARCHES elif os == "linux-aarch64": # Only want the one arch as the CPU type is different and # uses different build/test scripts @@ -371,7 +372,17 @@ def generate_wheels_matrix( ) and python_version == "3.13": continue + if use_split_build and ( + arch_version not in ["12.4", "12.1", "11.8", "cpu"] or os != "linux" + ): + raise RuntimeError( + "Split build is only supported on linux with cuda 12.4, 12.1, 11.8, and cpu.\n" + f"Currently attempting to build on arch version {arch_version} and os {os}.\n" + "Please modify the matrix generation to exclude this combination." + ) + # 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install + if ( arch_version in ["12.4", "12.1", "11.8"] and os == "linux" @@ -385,6 +396,7 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), + "use_split_build": "True" if use_split_build else "False", "devtoolset": ( "cxx11-abi" if arch_version == "cuda-aarch64" else "" ), @@ -400,7 +412,8 @@ def generate_wheels_matrix( ), } ) - if arch_version != "cuda-aarch64": + # Special build building to use on Colab. Python 3.11 for 12.1 CUDA + if python_version == "3.11" and arch_version == "12.1": ret.append( { "python_version": python_version, @@ -409,40 +422,16 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), - "use_split_build": "True", + "use_split_build": "True" if use_split_build else "False", "devtoolset": "", "container_image": WHEEL_CONTAINER_IMAGES[arch_version], "package_type": package_type, - "pytorch_extra_install_requirements": ( - PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] # fmt: skip - if os != "linux-aarch64" - else "" - ), - "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-split".replace( # noqa: B950 + "pytorch_extra_install_requirements": "", + "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950 ".", "_" ), } ) - # Special build building to use on Colab. PyThon 3.10 for 12.1 CUDA - if python_version == "3.10" and arch_version == "12.1": - ret.append( - { - "python_version": python_version, - "gpu_arch_type": gpu_arch_type, - "gpu_arch_version": gpu_arch_version, - "desired_cuda": translate_desired_cuda( - gpu_arch_type, gpu_arch_version - ), - "use_split_build": "False", - "devtoolset": "", - "container_image": WHEEL_CONTAINER_IMAGES[arch_version], - "package_type": package_type, - "pytorch_extra_install_requirements": "", - "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-full".replace( # noqa: B950 - ".", "_" - ), - } - ) else: ret.append( { @@ -452,6 +441,7 @@ def generate_wheels_matrix( "desired_cuda": translate_desired_cuda( gpu_arch_type, gpu_arch_version ), + "use_split_build": "True" if use_split_build else "False", "devtoolset": ( "cxx11-abi" if arch_version == "cpu-cxx11-abi" else "" ), @@ -462,11 +452,12 @@ def generate_wheels_matrix( ), "pytorch_extra_install_requirements": ( PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"] # fmt: skip - if os != "linux" + if os != "linux" and gpu_arch_type != "xpu" else "" ), } ) + return ret diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 8a66ebe5f59b5f..f9c857a3ed9cb9 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -61,6 +61,7 @@ class BinaryBuildWorkflow: # Mainly for macos cross_compile_arm64: bool = False macos_runner: str = "macos-14-xlarge" + use_split_build: bool = False def __post_init__(self) -> None: if self.abi_version: @@ -69,6 +70,9 @@ def __post_init__(self) -> None: ) else: self.build_environment = f"{self.os}-binary-{self.package_type}" + if self.use_split_build: + # added to distinguish concurrency groups + self.build_environment += "-split" def generate_workflow_file(self, workflow_template: jinja2.Template) -> None: output_file_path = ( @@ -110,6 +114,20 @@ class OperatingSystem: isolated_workflow=True, ), ), + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + use_split_build=True, + arches=["11.8", "12.1", "12.4", "cpu"], + ), + ciflow_config=CIFlowConfig( + labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, + isolated_workflow=True, + ), + use_split_build=True, + ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="conda", @@ -162,6 +180,21 @@ class OperatingSystem: ), branches="main", ), + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + arches=["11.8", "12.1", "12.4"], + python_versions=["3.9"], + use_split_build=True, + ), + ciflow_config=CIFlowConfig( + labels={LABEL_CIFLOW_PERIODIC}, + ), + branches="main", + use_split_build=True, + ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index 5acef33903ba52..a5206bd675fe6d 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -46,16 +46,24 @@ def gh_fetch_url_and_headers( with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn: return conn.headers, reader(conn) except HTTPError as err: - if err.code == 403 and all( - key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"] + if ( + err.code == 403 + and all( + key in err.headers + for key in ["X-RateLimit-Limit", "X-RateLimit-Remaining"] + ) + and int(err.headers["X-RateLimit-Remaining"]) == 0 ): print( - f"""Rate limit exceeded: + f"""{url} + Rate limit exceeded: Used: {err.headers['X-RateLimit-Used']} Limit: {err.headers['X-RateLimit-Limit']} Remaining: {err.headers['X-RateLimit-Remaining']} Resets at: {err.headers['x-RateLimit-Reset']}""" ) + else: + print(f"Error fetching {url} {err}") raise @@ -160,6 +168,14 @@ def gh_post_commit_comment( ) +def gh_close_pr(org: str, repo: str, pr_num: int, dry_run: bool = False) -> None: + url = f"{GITHUB_API_URL}/repos/{org}/{repo}/pulls/{pr_num}" + if dry_run: + print(f"Dry run closing PR {pr_num}") + else: + gh_fetch_url(url, method="PATCH", data={"state": "closed"}) + + def gh_delete_comment(org: str, repo: str, comment_id: int) -> None: url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}" gh_fetch_url(url, method="DELETE") diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index 7aa71e7c688595..641e438b784510 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -3,49 +3,94 @@ """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default -https://github.com/pytorch/test-infra/issues/5132) as a user list to determine -which users will get their jobs to run on experimental runners. This user list -is also a comma separated list of additional features or experiments which the -user could be opted in to. +https://github.com/pytorch/test-infra/issues/5132) to define the configuration +of which runners should be used to run which job. + +The configuration has two parts, the settings and a list of opted-in users, +separated by a line containing "---". If the line is not present, the +settings are considered to be empty with only the second part, the user +list, defined. + +The first part is a YAML block that defines the rollout settings. This can be +used to define any settings that are needed to determine which runners to use. +It's fields are defined by the RolloutSettings class below. + +The second part is a list of users who are explicitly opted in to the LF fleet. +The user list is also a comma separated list of additional features or +experiments which the user could be opted in to. The user list has the following rules: -- Users are GitHub usernames with the @ prefix -- If the first line is a "*" then all users will use the new runners -- If the first line is a "!" then all users will use the old runners +- Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable -- A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. +- A "#" prefix opts the user out of all experiments + +Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 + + --- -Example user list: + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os +import random from argparse import ArgumentParser from logging import LogRecord -from typing import Any, Iterable +from typing import Any, Dict, Iterable, List, NamedTuple, Tuple +import yaml from github import Auth, Github from github.Issue import Issue -WORKFLOW_LABEL_META = "" # use meta runners +DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation -RUNNER_AMI_LEGACY = "" -RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" +SETTING_EXPERIMENTS = "experiments" + +LF_FLEET_EXPERIMENT = "lf" +CANARY_FLEET_SUFFIX = ".c" + + +class Experiment(NamedTuple): + rollout_perc: float = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + +class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -137,11 +182,14 @@ def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: def get_potential_pr_author( - gh: Github, repo: str, username: str, ref_type: str, ref_name: str + github_token: str, repo: str, username: str, ref_type: str, ref_name: str ) -> str: # If the trigger was a new tag added by a bot, this is a ciflow case # Fetch the actual username from the original PR. The PR number is # embedded in the tag name: ciflow// + + gh = get_gh_client(github_token) + if username == "pytorch-bot[bot]" and ref_type == "tag": split_tag = ref_name.split("/") if ( @@ -163,126 +211,233 @@ def get_potential_pr_author( def is_exception_branch(branch: str) -> bool: + """ + Branches that get opted out of all experiments and should always use Meta runners + """ return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} -def get_workflow_type(issue: Issue, workflow_requestors: Iterable[str]) -> str: +def load_yaml(yaml_text: str) -> Any: try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - - if first_comment[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif first_comment[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in first_comment.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise - except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META +def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: + """ + Extracts the text with settings, if any, and the opted in users from the rollout state. -def get_optin_feature( - issue: Issue, workflow_requestors: Iterable[str], feature: str, fallback: str -) -> str: + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the rest is the users. + """ + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state + + +class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into + """ + + +def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: + """ + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + + + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + +def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings + """ try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - userlist = {u.lstrip("#").strip("\n\t@ ") for u in first_comment.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: - log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." - ) - return feature - else: + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) + + except Exception: + log.exception("Failed to parse settings") + + return Settings() + + +def parse_settings(rollout_state: str) -> Settings: + """ + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the default values are used. + """ + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + +def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + +def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + +def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False +) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return fallback + enabled = True + elif experiment_settings.rollout_perc: + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + if random.uniform(0, 100) <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" + + +def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: + """ + Gets the first comment of the issue, which contains the desired rollout state. + + The default issue we use - https://github.com/pytorch/test-infra/issues/5132 + """ + gh = get_gh_client(github_token) + issue = get_issue(gh, repo, issue_num) + return str(issue.get_comments()[0].body.strip("\n\t ")) def main() -> None: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: - gh = get_gh_client(args.github_token) - # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 - issue = get_issue(gh, args.github_issue_repo, args.github_issue) + rollout_state = get_rollout_state_from_issue( + args.github_token, args.github_issue_repo, args.github_issue + ) + username = get_potential_pr_author( - gh, + args.github_token, args.github_repo, args.github_actor, args.github_ref_type, args.github_branch, ) - label_type = get_workflow_type( - issue, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - issue=issue, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": diff --git a/.github/scripts/s390x-ci/README.md b/.github/scripts/s390x-ci/README.md index f62b02e24aa3e9..94e4a85be43ceb 100644 --- a/.github/scripts/s390x-ci/README.md +++ b/.github/scripts/s390x-ci/README.md @@ -3,7 +3,7 @@ ## Install prerequisites. ``` -$ sudo dnf install docker +$ sudo dnf install podman podman-docker jq ``` ## Add services. @@ -27,23 +27,48 @@ $ sudo systemctl enable --now qemu-user-static ## Rebuild the image -In order to build or update the `iiilinuxibmcom/actions-runner` image, e.g. to get the -latest OS security fixes, use the following commands: +First build s390x builder image `docker.io/pytorch/manylinuxs390x-builder`, +using following commands: + +``` +$ cd ~ +$ git clone https://github.com/pytorch/pytorch +$ cd pytorch +$ git submodule update --init --recursive +$ GPU_ARCH_TYPE=cpu-s390x "$(pwd)/.ci/docker/manywheel/build.sh" manylinuxs390x-builder +$ docker image tag localhost/pytorch/manylinuxs390x-builder docker.io/pytorch/manylinuxs390x-builder:cpu-s390x +$ docker image save -o ~/manywheel-s390x.tar docker.io/pytorch/manylinuxs390x-builder:cpu-s390x +``` + +Next step is to build `actions-runner` image using: ``` $ cd self-hosted-builder $ sudo docker build \ - --build-arg repo=/ \ - --build-arg token=<***> \ --pull \ -f actions-runner.Dockerfile \ - -t iiilinuxibmcom/actions-runner \ + -t iiilinuxibmcom/actions-runner. \ . ``` -If it fails, ensure that selinux doesn't prevent it from working. +If there are failures, ensure that selinux doesn't prevent it from working. In worst case, selinux can be disabled with `setenforce 0`. +Now prepare all necessary files for runner registration: + +``` +$ sudo mkdir -p /etc/actions-runner/ +$ sudo chmod 700 /etc/actions-runner/ +$ sudo /bin/cp /etc/actions-runner//key_private.pem +$ sudo echo | sudo tee /etc/actions-runner//appid.env +$ sudo echo | sudo tee /etc/actions-runner//installid.env +$ sudo echo NAME= | sudo tee /etc/actions-runner//env +$ sudo echo ORG= | sudo tee -a /etc/actions-runner//env +$ cd self-hosted-builder +$ sudo /bin/cp helpers/*.sh /usr/local/bin/ +$ sudo chmod 755 /usr/local/bin/app_token.sh /usr/local/bin/gh_token_generator.sh +``` + ## Autostart the runner. ``` diff --git a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile index 416a6d8e50df5e..ee1db829fe66ce 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile +++ b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner.Dockerfile @@ -1,12 +1,12 @@ # Self-Hosted IBM Z Github Actions Runner. # Temporary image: amd64 dependencies. -FROM docker.io/amd64/ubuntu:22.04 as ld-prefix +FROM docker.io/amd64/ubuntu:23.10 as ld-prefix ENV DEBIAN_FRONTEND=noninteractive -RUN apt-get update && apt-get -y install ca-certificates libicu70 libssl3 +RUN apt-get update && apt-get -y install ca-certificates libicu72 libssl3 # Main image. -FROM docker.io/s390x/ubuntu:22.04 +FROM docker.io/s390x/ubuntu:23.10 # Packages for pytorch building and testing. ENV DEBIAN_FRONTEND=noninteractive @@ -16,6 +16,7 @@ RUN apt-get update && apt-get -y install \ gcc \ git \ jq \ + zip \ libxml2-dev \ libxslt-dev \ ninja-build \ @@ -43,24 +44,28 @@ COPY fs/ / RUN chmod +x /usr/bin/actions-runner /usr/bin/entrypoint +# install podman +RUN apt -y install podman podman-docker + # amd64 Github Actions Runner. RUN useradd -m actions-runner USER actions-runner WORKDIR /home/actions-runner -RUN curl -L https://github.com/actions/runner/releases/download/v2.309.0/actions-runner-linux-x64-2.309.0.tar.gz | tar -xz -# repository -ARG repo +# set up python virtual environment which is later used by runner. +# build workflows use "python -m pip install ...", +# and it doesn't work for non-root user +RUN virtualenv --system-site-packages venv -# repository token -ARG token +# copy prebuilt manywheel docker image for builds and tests +# build command is: +# GPU_ARCH_TYPE=cpu-s390x "$(pwd)/manywheel/build_docker.sh" +# and save command is: +# docker image save -o manywheel-s390x.tar pytorch/manylinuxs390x-builder:cpu-s390x +# +COPY --chown=actions-runner:actions-runner manywheel-s390x.tar /home/actions-runner/manywheel-s390x.tar -RUN ./config.sh \ - --unattended \ - --url "https://github.com/${repo}" \ - --token "${token}" \ - --no-default-labels \ - --labels self-hosted,linux.s390x +RUN curl -L https://github.com/actions/runner/releases/download/v2.317.0/actions-runner-linux-x64-2.317.0.tar.gz | tar -xz ENTRYPOINT ["/usr/bin/entrypoint"] CMD ["/usr/bin/actions-runner"] diff --git a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service index 158be9ccb6c1d0..323b00edc178bc 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service +++ b/.github/scripts/s390x-ci/self-hosted-builder/actions-runner@.service @@ -8,12 +8,16 @@ StartLimitIntervalSec=0 Type=simple Restart=always ExecStartPre=-/usr/bin/docker rm --force actions-runner.%i +ExecStartPre=-/usr/local/bin/gh_token_generator.sh /etc/actions-runner/%i/appid.env /etc/actions-runner/%i/installid.env /etc/actions-runner/%i/key_private.pem /etc/actions-runner/%i/ghtoken.env ExecStart=/usr/bin/docker run \ + --env-file=/etc/actions-runner/%i/env \ + --env-file=/etc/actions-runner/%i/ghtoken.env \ --init \ --interactive \ --name=actions-runner.%i \ --rm \ - iiilinuxibmcom/actions-runner + --privileged \ + iiilinuxibmcom/actions-runner.%i ExecStop=/bin/sh -c "docker exec actions-runner.%i kill -INT -- -1" ExecStop=/bin/sh -c "docker wait actions-runner.%i" ExecStop=/bin/sh -c "docker rm actions-runner.%i" diff --git a/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner b/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner index 760784b21c3966..6d129d8656944b 100644 --- a/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner +++ b/.github/scripts/s390x-ci/self-hosted-builder/fs/usr/bin/actions-runner @@ -2,5 +2,45 @@ set -e -u +# first import docker image +if [ -f ./manywheel-s390x.tar ] ; then + docker image load --input manywheel-s390x.tar + docker image tag docker.io/pytorch/manylinuxs390x-builder:cpu-s390x docker.io/pytorch/manylinuxs390x-builder:cpu-s390x-main + rm -f manywheel-s390x.tar +fi + +token_file=registration-token.json + +# Generate registration token +curl \ + -X POST \ + -H "Accept: application/vnd.github.v3+json" \ + -H "Authorization: Bearer ${ACCESS_TOKEN}" \ + "https://api.github.com/orgs/${ORG}/actions/runners/registration-token" \ + -o "$token_file" + +unset ACCESS_TOKEN + +# register runner as ephemeral runner +# it does one job, stops and unregisters +registration_token=$(jq --raw-output .token "$token_file") + +./config.sh \ + --unattended \ + --ephemeral \ + --url "https://github.com/${ORG}" \ + --token "${registration_token}" \ + --name "${NAME}" \ + --no-default-labels \ + --labels self-hosted,linux.s390x + +unset registration_token +rm -f "$token_file" + +# enter into python virtual environment. +# build workflows use "python -m pip install ...", +# and it doesn't work for non-root user +source venv/bin/activate + # Run one job. -./run.sh --once +./run.sh diff --git a/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh b/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh new file mode 100755 index 00000000000000..eb483c197772ad --- /dev/null +++ b/.github/scripts/s390x-ci/self-hosted-builder/helpers/app_token.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# +# Request an ACCESS_TOKEN to be used by a GitHub APP +# Environment variable that need to be set up: +# * APP_ID, the GitHub's app ID +# * INSTALL_ID, the Github's app's installation ID +# * APP_PRIVATE_KEY, the content of GitHub app's private key in PEM format. +# +# https://github.com/orgs/community/discussions/24743#discussioncomment-3245300 +# + +set -o pipefail + +_GITHUB_HOST=${GITHUB_HOST:="github.com"} + +# If URL is not github.com then use the enterprise api endpoint +if [[ ${GITHUB_HOST} = "github.com" ]]; then + URI="https://api.${_GITHUB_HOST}" +else + URI="https://${_GITHUB_HOST}/api/v3" +fi + +API_VERSION=v3 +API_HEADER="Accept: application/vnd.github.${API_VERSION}+json" +CONTENT_LENGTH_HEADER="Content-Length: 0" +APP_INSTALLATIONS_URI="${URI}/app/installations" + + +# JWT parameters based off +# https://docs.github.com/en/developers/apps/building-github-apps/authenticating-with-github-apps#authenticating-as-a-github-app +# +# JWT token issuance and expiration parameters +JWT_IAT_DRIFT=60 +JWT_EXP_DELTA=600 + +JWT_JOSE_HEADER='{ + "alg": "RS256", + "typ": "JWT" +}' + + +build_jwt_payload() { + now=$(date +%s) + iat=$((now - JWT_IAT_DRIFT)) + jq -c \ + --arg iat_str "${iat}" \ + --arg exp_delta_str "${JWT_EXP_DELTA}" \ + --arg app_id_str "${APP_ID}" \ + ' + ($iat_str | tonumber) as $iat + | ($exp_delta_str | tonumber) as $exp_delta + | ($app_id_str | tonumber) as $app_id + | .iat = $iat + | .exp = ($iat + $exp_delta) + | .iss = $app_id + ' <<< "{}" | tr -d '\n' +} + +base64url() { + base64 | tr '+/' '-_' | tr -d '=\n' +} + +rs256_sign() { + openssl dgst -binary -sha256 -sign <(echo "$1") +} + +request_access_token() { + jwt_payload=$(build_jwt_payload) + encoded_jwt_parts=$(base64url <<<"${JWT_JOSE_HEADER}").$(base64url <<<"${jwt_payload}") + encoded_mac=$(echo -n "$encoded_jwt_parts" | rs256_sign "${APP_PRIVATE_KEY}" | base64url) + generated_jwt="${encoded_jwt_parts}.${encoded_mac}" + + auth_header="Authorization: Bearer ${generated_jwt}" + + app_installations_response=$(curl -sX POST \ + -H "${auth_header}" \ + -H "${API_HEADER}" \ + --header "X-GitHub-Api-Version: 2022-11-28" \ + --url "https://api.github.com/app/installations/${INSTALL_ID}/access_tokens" \ + ) + echo "$app_installations_response" | jq --raw-output '.token' +} + +request_access_token diff --git a/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh b/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh new file mode 100755 index 00000000000000..8f16974423dd77 --- /dev/null +++ b/.github/scripts/s390x-ci/self-hosted-builder/helpers/gh_token_generator.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$(dirname "$0") +APP_ID=$1 +INSTALL_ID=$2 +APP_PRIVATE_KEY=$3 +DST_FILE="$4" + +ACCESS_TOKEN="$(APP_ID="$(<"${APP_ID}")" INSTALL_ID="$(<"${INSTALL_ID}")" APP_PRIVATE_KEY="$(<"${APP_PRIVATE_KEY}")" "${SCRIPT_DIR}/app_token.sh")" +echo "ACCESS_TOKEN=${ACCESS_TOKEN}" > "${DST_FILE}" diff --git a/.github/scripts/sync_distributed_folder_prototype.sh b/.github/scripts/sync_distributed_folder_prototype.sh deleted file mode 100755 index d31fef5c79c428..00000000000000 --- a/.github/scripts/sync_distributed_folder_prototype.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -set -eoux pipefail - -SYNC_BRANCH=pytorch-stable-prototype - -git config user.email "fake@example.com" -git config user.name "PyTorch Stable Bot" - -git fetch origin main -git fetch origin "$SYNC_BRANCH" -git checkout "$SYNC_BRANCH" - -# Using a hardcoded SHA here is a massive speedup as we can skip the entire history of the pytorch GitHub repo. -# This specific SHA was chosen as it was before the "branch point" of the stable branch -for SHA in $(git log ba3b05fdf37ddbc3c301294d6a560a816335e717..origin/main --pretty="%h" -- torch/distributed torch/csrc/distributed test/distributed test/cpp/c10d benchmarks/distributed) -do - # `git merge-base --is-ancestor` exits with code 0 if the given SHA is an ancestor, and non-0 otherwise - if git merge-base --is-ancestor $SHA HEAD || [[ $(git log --grep="(cherry picked from commit $SHA") ]] - then - echo "Skipping $SHA" - continue - fi - echo "Copying $SHA" - git cherry-pick -x "$SHA" -X theirs - git reset --soft HEAD~1 - git add torch/distributed torch/csrc/distributed test/distributed test/cpp/c10d benchmarks/distributed - git checkout . - git commit --reuse-message=HEAD@{1} - git clean -f -done - -if [[ "${WITH_PUSH}" == true ]]; then - git push -fi diff --git a/.github/scripts/tag_docker_images_for_release.py b/.github/scripts/tag_docker_images_for_release.py index 62a4f21ae7dd20..73dfda03b1c82c 100644 --- a/.github/scripts/tag_docker_images_for_release.py +++ b/.github/scripts/tag_docker_images_for_release.py @@ -51,6 +51,8 @@ def main() -> None: for platform_image in platform_images: # type: ignore[attr-defined] for arch in platform_image.keys(): # type: ignore[attr-defined] + if arch == "cpu-s390x": + continue tag_image( platform_image[arch], # type: ignore[index] default_tag, diff --git a/.github/scripts/test_check_labels.py b/.github/scripts/test_check_labels.py index 2b2cd7b6c5204b..1c921f2eafa9a3 100644 --- a/.github/scripts/test_check_labels.py +++ b/.github/scripts/test_check_labels.py @@ -18,6 +18,7 @@ def mock_parse_args() -> object: class Object: def __init__(self) -> None: self.pr_num = 76123 + self.exit_non_zero = False return Object() diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py new file mode 100644 index 00000000000000..b20b8a68cbe084 --- /dev/null +++ b/.github/scripts/test_runner_determinator.py @@ -0,0 +1,237 @@ +from unittest import main, TestCase +from unittest.mock import Mock, patch + +import runner_determinator as rd + + +class TestRunnerDeterminatorIssueParser(TestCase): + def test_parse_settings(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_settings_in_code_block(self) -> None: + settings_text = """ + + ``` + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 0 + + ``` + + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + settings = rd.parse_settings(settings_text) + + self.assertTupleEqual( + rd.Experiment(rollout_perc=25), + settings.experiments["lf"], + "lf settings not parsed correctly", + ) + self.assertTupleEqual( + rd.Experiment(rollout_perc=0), + settings.experiments["otherExp"], + "otherExp settings not parsed correctly", + ) + + def test_parse_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + def test_parse_users_without_settings(self) -> None: + settings_text = """ + + @User1,lf + @User2,lf,otherExp + + """ + + users = rd.parse_users(settings_text) + self.assertDictEqual( + {"User1": ["lf"], "User2": ["lf", "otherExp"]}, + users, + "Users not parsed correctly", + ) + + +class TestRunnerDeterminatorGetRunnerPrefix(TestCase): + def test_opted_in_user(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") + + def test_opted_in_user_two_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + + @patch("random.uniform", return_value=50) + def test_opted_out_user(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=10) + def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into both experiments by the 10% rollout + prefix = rd.get_runner_prefix(settings_text, ["User3"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_lf_prefix_always_comes_first(self) -> None: + settings_text = """ + experiments: + otherExp: + rollout_perc: 0 + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf + @User2,otherExp,lf + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User2"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + def test_ignores_commented_users(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + #@User1,lf + @User2,lf,otherExp + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("", prefix, "Runner prefix not correct for user") + + def test_ignores_extra_experiments(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + foo: + rollout_perc: 0 + --- + + Users: + @User1,lf,otherExp,foo + + """ + + prefix = rd.get_runner_prefix(settings_text, ["User1"]) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 066472a8042100..a1a60698af7950 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -36,6 +36,7 @@ import yaml from github_utils import ( + gh_close_pr, gh_fetch_json_list, gh_fetch_merge_base, gh_fetch_url, @@ -1174,11 +1175,11 @@ def merge_into( for pr in additional_merged_prs: pr.add_numbered_label(MERGE_COMPLETE_LABEL, dry_run) - if comment_id and self.pr_num: - # When the merge process reaches this part, we can assume that the commit - # has been successfully pushed to trunk - merge_commit_sha = repo.rev_parse(name=REMOTE_MAIN_BRANCH) + # When the merge process reaches this part, we can assume that the commit + # has been successfully pushed to trunk + merge_commit_sha = repo.rev_parse(name=self.default_branch()) + if comment_id and self.pr_num: # Finally, upload the record to Rockset. The list of pending and failed # checks are at the time of the merge save_merge_record( @@ -1203,6 +1204,17 @@ def merge_into( else: print("Missing comment ID or PR number, couldn't upload to Rockset") + # Usually Github will see that the commit has "resolves " in the + # commit message and close the PR, but sometimes it doesn't, leading to + # confusion. When it doesn't, we close it manually. + time.sleep(60) # Give Github some time to close the PR + manually_close_merged_pr( + pr=self, + additional_merged_prs=additional_merged_prs, + merge_commit_sha=merge_commit_sha, + dry_run=dry_run, + ) + def merge_changes( self, repo: GitRepo, @@ -1503,6 +1515,34 @@ def checks_to_markdown_bullets( ] +def manually_close_merged_pr( + pr: GitHubPR, + additional_merged_prs: List[GitHubPR], + merge_commit_sha: str, + dry_run: bool, +) -> None: + def _comment_and_close(pr: GitHubPR, comment: str) -> None: + pr = GitHubPR(pr.org, pr.project, pr.pr_num) # Refresh the PR + if not pr.is_closed(): + gh_post_pr_comment(pr.org, pr.project, pr.pr_num, comment, dry_run) + gh_close_pr(pr.org, pr.project, pr.pr_num, dry_run) + + message = ( + f"This PR (#{pr.pr_num}) was merged in {merge_commit_sha} but it is still open, likely due to a Github bug, " + "so mergebot is closing it manually. If you think this is a mistake, please feel free to reopen and contact Dev Infra." + ) + _comment_and_close(pr, message) + for additional_pr in additional_merged_prs: + message = ( + f"This PR (#{additional_pr.pr_num}) was merged as part of PR #{pr.pr_num} in the stack under {merge_commit_sha} " + "but it is still open, likely due to a Github bug, so mergebot is closing it manually. " + "If you think this is a mistake, please feel free to reopen and contact Dev Infra." + ) + _comment_and_close(additional_pr, message) + + print(f"PR {pr.pr_num} and all additional PRs in the stack have been closed.") + + @retries_decorator() def save_merge_record( comment_id: int, diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 38b90e919c3341..8db7da9456a6b9 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -1,7 +1,7 @@ {%- set upload_artifact_s3_action = "seemethere/upload-artifact-s3@v5" -%} {%- set download_artifact_s3_action = "seemethere/download-artifact-s3@v4" -%} -{%- set upload_artifact_action = "actions/upload-artifact@v3" -%} -{%- set download_artifact_action = "actions/download-artifact@v3" -%} +{%- set upload_artifact_action = "actions/upload-artifact@v4.4.0" -%} +{%- set download_artifact_action = "actions/download-artifact@v4.1.7" -%} {%- set timeout_minutes = 240 -%} diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 42cfa8eb3b07ca..918151b486d661 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -68,17 +68,16 @@ jobs: needs: get-label-type with:!{{ upload.binary_env_as_input(config) }} {%- if "aarch64" in build_environment %} - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" {%- elif "s390x" in build_environment %} runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" {%- elif "conda" in build_environment and config["gpu_arch_type"] == "cuda" %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral {%- else %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" {%- endif %} build_name: !{{ config["build_name"] }} build_environment: !{{ build_environment }} @@ -103,7 +102,6 @@ jobs: build_name: !{{ config["build_name"] }} build_environment: !{{ build_environment }} {%- if "aarch64" in build_environment %} - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" {%- elif "s390x" in build_environment %} @@ -112,10 +110,10 @@ jobs: {%- elif config["gpu_arch_type"] == "rocm" %} runs_on: linux.rocm.gpu {%- elif config["gpu_arch_type"] == "cuda" %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu {%- else %} - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge {%- endif %} secrets: diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index af417054ab190c..272073dd902e6d 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -101,7 +101,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: !{{ config["build_name"] }} diff --git a/.github/templates/upload.yml.j2 b/.github/templates/upload.yml.j2 index f4d206d4d7539d..4494af7ac50b33 100644 --- a/.github/templates/upload.yml.j2 +++ b/.github/templates/upload.yml.j2 @@ -45,7 +45,7 @@ {%- if is_windows %} # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" {%- endif %} {%- else %} diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index d5aca578b9024a..9ba9af06a2ef45 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -53,10 +53,24 @@ env: !{{ common.concurrency(build_environment) }} jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + {%- for config in build_configs %} !{{ config["build_name"] }}-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + {%- if branches == "nightly" %} + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + {%- else %} + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + {%- endif %} timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config, True) }} {%- if config.pytorch_extra_install_requirements is defined and config.pytorch_extra_install_requirements|d('')|length > 0 %} @@ -85,15 +99,17 @@ jobs: !{{ common.wait_and_kill_ssh_windows('pytorch') }} !{{ config["build_name"] }}-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: !{{ config["build_name"] }}-build + needs: + - !{{ config["build_name"] }}-build + - get-label-type {%- if config["gpu_arch_type"] == "cuda" %} {%- if branches == "nightly" %} - runs-on: windows.8xlarge.nvidia.gpu + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" {%- else %} - runs-on: windows.8xlarge.nvidia.gpu.nonephemeral + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge.nonephemeral" {%- endif %} {%- else %} - runs-on: windows.4xlarge.nonephemeral + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" {%- endif %} timeout-minutes: !{{ common.timeout_minutes }} !{{ upload.binary_env(config, True) }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index deec5e35dda219..509312c30bdfea 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -283,7 +283,7 @@ jobs: # Ensure the working directory gets chowned back to the current user docker run --rm -v "${RUNNER_TEMP}/artifacts:/v" -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" . - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' }} with: name: ${{ inputs.build_name }} diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index fa590499d6d52b..5123889fb01bf8 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -210,7 +210,7 @@ jobs: - name: Download Build Artifacts if: ${{ steps.filter.outputs.is-test-matrix-empty == 'False' }} - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: ${{ inputs.build_name }} path: "${{ runner.temp }}/artifacts/" diff --git a/.github/workflows/_binary-upload.yml b/.github/workflows/_binary-upload.yml index 1c5a9294598b78..927f72c8d83897 100644 --- a/.github/workflows/_binary-upload.yml +++ b/.github/workflows/_binary-upload.yml @@ -126,7 +126,7 @@ jobs: # NB: When the previous build job is skipped, there won't be any artifacts and # this step will fail. Binary build jobs can only be skipped on CI, not nightly continue-on-error: true - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: ${{ inputs.build_name }} path: "${{ runner.temp }}/artifacts/" diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index bb91b43e319d25..95fe6bd1a3b50a 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -292,7 +292,7 @@ jobs: bundler-cache: true - name: Download arm64 artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: name: pytorch-ios-build-artifacts-arm64 diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 9d6563e2ea5af0..862ceceec181fb 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -62,49 +62,94 @@ jobs: """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default - https://github.com/pytorch/test-infra/issues/5132) as a user list to determine - which users will get their jobs to run on experimental runners. This user list - is also a comma separated list of additional features or experiments which the - user could be opted in to. + https://github.com/pytorch/test-infra/issues/5132) to define the configuration + of which runners should be used to run which job. + + The configuration has two parts, the settings and a list of opted-in users, + separated by a line containing "---". If the line is not present, the + settings are considered to be empty with only the second part, the user + list, defined. + + The first part is a YAML block that defines the rollout settings. This can be + used to define any settings that are needed to determine which runners to use. + It's fields are defined by the RolloutSettings class below. + + The second part is a list of users who are explicitly opted in to the LF fleet. + The user list is also a comma separated list of additional features or + experiments which the user could be opted in to. The user list has the following rules: - - Users are GitHub usernames with the @ prefix - - If the first line is a "*" then all users will use the new runners - - If the first line is a "!" then all users will use the old runners + - Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable - - A "#" prefix indicates the user is opted out of the new runners but is opting - into features/experiments. + - A "#" prefix opts the user out of all experiments + + Example config: + # A list of experiments that can be opted into. + # This defines the behavior they'll induce when opted into. + # Expected syntax is: + # [experiment_name]: # Name of the experiment. Also used for the label prefix. + # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. + + experiments: + lf: + rollout_percent: 25 + + --- - Example user list: + # Opt-ins: + # Users can opt into the LF fleet by adding their GitHub username to this list + # and specifying experiments to enable in a comma-separated list. + # Experiments should be from the above list. - @User1 - @User2,amz2023 - #@UserOptOutOfNewRunner,amz2023 + @User1,lf,split_build + @User2,lf + @User3,split_build """ import logging import os + import random from argparse import ArgumentParser from logging import LogRecord - from typing import Any, Iterable + from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + import yaml from github import Auth, Github from github.Issue import Issue - WORKFLOW_LABEL_META = "" # use meta runners + DEFAULT_LABEL_PREFIX = "" # use meta runners WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation - RUNNER_AMI_LEGACY = "" - RUNNER_AMI_AMZ2023 = "amz2023" - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" + SETTING_EXPERIMENTS = "experiments" + + LF_FLEET_EXPERIMENT = "lf" + CANARY_FLEET_SUFFIX = ".c" + + + class Experiment(NamedTuple): + rollout_perc: float = ( + 0 # Percentage of workflows to experiment on when user is not opted-in. + ) + + # Add more fields as needed + + + class Settings(NamedTuple): + """ + Settings for the experiments that can be opted into. + """ + + experiments: Dict[str, Experiment] = {} + + class ColorFormatter(logging.Formatter): """Color codes the log messages based on the log level""" @@ -196,11 +241,14 @@ jobs: def get_potential_pr_author( - gh: Github, repo: str, username: str, ref_type: str, ref_name: str + github_token: str, repo: str, username: str, ref_type: str, ref_name: str ) -> str: # If the trigger was a new tag added by a bot, this is a ciflow case # Fetch the actual username from the original PR. The PR number is # embedded in the tag name: ciflow// + + gh = get_gh_client(github_token) + if username == "pytorch-bot[bot]" and ref_type == "tag": split_tag = ref_name.split("/") if ( @@ -222,130 +270,238 @@ jobs: def is_exception_branch(branch: str) -> bool: + """ + Branches that get opted out of all experiments and should always use Meta runners + """ return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - def get_workflow_type(issue: Issue, workflow_requestors: Iterable[str]) -> str: + def load_yaml(yaml_text: str) -> Any: try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - - if first_comment[0] == "!": - log.info("LF Workflows are disabled for everyone. Using meta runners.") - return WORKFLOW_LABEL_META - elif first_comment[0] == "*": - log.info("LF Workflows are enabled for everyone. Using LF runners.") - return WORKFLOW_LABEL_LF - else: - all_opted_in_users = { - usr_raw.strip("\n\t@ ").split(",")[0] - for usr_raw in first_comment.split() - } - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - if opted_in_requestors: - log.info( - f"LF Workflows are enabled for {', '.join(opted_in_requestors)}. Using LF runners." - ) - return WORKFLOW_LABEL_LF - else: - log.info( - f"LF Workflows are disabled for {', '.join(workflow_requestors)}. Using meta runners." - ) - return WORKFLOW_LABEL_META + data = yaml.safe_load(yaml_text) + return data + except yaml.YAMLError as exc: + log.exception("Error loading YAML") + raise - except Exception as e: - log.error( - f"Failed to get determine workflow type. Falling back to meta runners. Exception: {e}" - ) - return WORKFLOW_LABEL_META + def extract_settings_user_opt_in_from_text(rollout_state: str) -> Tuple[str, str]: + """ + Extracts the text with settings, if any, and the opted in users from the rollout state. - def get_optin_feature( - issue: Issue, workflow_requestors: Iterable[str], feature: str, fallback: str - ) -> str: + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the rest is the users. + """ + rollout_state_parts = rollout_state.split("---") + if len(rollout_state_parts) >= 2: + return rollout_state_parts[0], rollout_state_parts[1] + else: + return "", rollout_state + + + class UserOptins(Dict[str, List[str]]): + """ + Dictionary of users with a list of features they have opted into + """ + + + def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: + """ + Parse the user opt-in text into a key value pair of username and the list of features they have opted into + + Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. + - Example line: "@User1,lf,split_build" + - A "#" prefix indicates the user is opted out of all experiments + + + """ + optins = UserOptins() + for user in user_optin_text.split("\n"): + user = user.strip("\r\n\t -") + if not user or not user.startswith("@"): + # Not a valid user. Skip + continue + + if user: + usr_name = user.split(",")[0].strip("@") + optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + + return optins + + + def parse_settings_from_text(settings_text: str) -> Settings: + """ + Parse the experiments from the issue body into a list of ExperimentSettings + """ try: - first_comment = issue.get_comments()[0].body.strip("\n\t ") - userlist = {u.lstrip("#").strip("\n\t@ ") for u in first_comment.split()} - all_opted_in_users = set() - for user in userlist: - for i in user.split(","): - if i == feature: - all_opted_in_users.add(user.split(",")[0]) - opted_in_requestors = { - usr for usr in workflow_requestors if usr in all_opted_in_users - } - - if opted_in_requestors: - log.info( - f"Feature {feature} is enabled for {', '.join(opted_in_requestors)}. Using feature {feature}." - ) - return feature - else: + if settings_text: + # Escape the backtick as well so that we can have the settings in a code block on the GH issue + # for easy reading + # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on + # the backtick character in shell commands. + backtick = chr(96) # backtick character + settings_text = settings_text.strip(f"\r\n\t{backtick} ") + settings = load_yaml(settings_text) + + # For now we just load experiments. We can expand this if/when we add more settings + experiments = {} + + for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): + valid_settings = {} + for setting in exp_settings: + if setting not in Experiment._fields: + log.warning( + f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" + ) + else: + valid_settings[setting] = exp_settings[setting] + + experiments[exp_name] = Experiment(**valid_settings) + return Settings(experiments) + + except Exception: + log.exception("Failed to parse settings") + + return Settings() + + + def parse_settings(rollout_state: str) -> Settings: + """ + Parse settings, if any, from the rollout state. + + If the issue body contains "---" then the text above that is the settings + and the text below is the list of opted in users. + + If it doesn't contain "---" then the settings are empty and the default values are used. + """ + settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) + return parse_settings_from_text(settings_text) + + + def parse_users(rollout_state: str) -> UserOptins: + """ + Parse users from the rollout state. + + """ + _, users_text = extract_settings_user_opt_in_from_text(rollout_state) + return parse_user_opt_in_from_text(users_text) + + + def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: + """ + Check if a user is opted into an experiment + """ + return experiment_name in user_optins.get(user, []) + + + def get_runner_prefix( + rollout_state: str, workflow_requestors: Iterable[str], is_canary: bool = False + ) -> str: + settings = parse_settings(rollout_state) + user_optins = parse_users(rollout_state) + + fleet_prefix = "" + prefixes = [] + for experiment_name, experiment_settings in settings.experiments.items(): + enabled = False + + # Is any workflow_requestor opted in to this experiment? + opted_in_users = [ + requestor + for requestor in workflow_requestors + if is_user_opted_in(requestor, user_optins, experiment_name) + ] + + if opted_in_users: log.info( - f"Feature {feature} is disabled for {', '.join(workflow_requestors)}. Using fallback \"{fallback}\"." + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) - return fallback + enabled = True + elif experiment_settings.rollout_perc: + # If no user is opted in, then we randomly enable the experiment based on the rollout percentage + if random.uniform(0, 100) <= experiment_settings.rollout_perc: + log.info( + f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." + ) + enabled = True + + if enabled: + label = experiment_name + if experiment_name == LF_FLEET_EXPERIMENT: + # We give some special treatment to the "lf" experiment since determines the fleet we use + # - If it's enabled, then we always list it's prefix first + # - If we're in the canary branch, then we append ".c" to the lf prefix + if is_canary: + label += CANARY_FLEET_SUFFIX + fleet_prefix = label + else: + prefixes.append(label) - except Exception as e: + if len(prefixes) > 1: log.error( - f'Failed to determine if user has opted-in to feature {feature}. Using fallback "{fallback}". Exception: {e}' + f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" ) - return fallback + prefixes = prefixes[:1] + + # Fleet always comes first + if fleet_prefix: + prefixes.insert(0, fleet_prefix) + + return ".".join(prefixes) + "." if prefixes else "" + + + def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: + """ + Gets the first comment of the issue, which contains the desired rollout state. + + The default issue we use - https://github.com/pytorch/test-infra/issues/5132 + """ + gh = get_gh_client(github_token) + issue = get_issue(gh, repo, issue_num) + return str(issue.get_comments()[0].body.strip("\n\t ")) def main() -> None: args = parse_args() if args.github_ref_type == "branch" and is_exception_branch(args.github_branch): - log.info(f"Exception branch: '{args.github_branch}', using meta runners") - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY + log.info( + f"Exception branch: '{args.github_branch}', using Meta runners and no experiments." + ) + runner_label_prefix = DEFAULT_LABEL_PREFIX else: try: - gh = get_gh_client(args.github_token) - # The default issue we use - https://github.com/pytorch/test-infra/issues/5132 - issue = get_issue(gh, args.github_issue_repo, args.github_issue) + rollout_state = get_rollout_state_from_issue( + args.github_token, args.github_issue_repo, args.github_issue + ) + username = get_potential_pr_author( - gh, + args.github_token, args.github_repo, args.github_actor, args.github_ref_type, args.github_branch, ) - label_type = get_workflow_type( - issue, - ( - args.github_issue_owner, - username, - ), - ) - runner_ami = get_optin_feature( - issue=issue, - workflow_requestors=( - args.github_issue_owner, - username, - ), - feature=RUNNER_AMI_AMZ2023, - fallback=RUNNER_AMI_LEGACY, + + is_canary = args.github_repo == "pytorch/pytorch-canary" + + runner_label_prefix = get_runner_prefix( + rollout_state, (args.github_issue_owner, username), is_canary ) + except Exception as e: log.error( - f"Failed to get issue. Falling back to meta runners. Exception: {e}" + f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" ) - label_type = WORKFLOW_LABEL_META - runner_ami = RUNNER_AMI_LEGACY - # For Canary builds use canary runners - if args.github_repo == "pytorch/pytorch-canary" and label_type == WORKFLOW_LABEL_LF: - label_type = WORKFLOW_LABEL_LF_CANARY - - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, label_type) - set_github_output(GH_OUTPUT_KEY_AMI, runner_ami) + set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) if __name__ == "__main__": main() + EOF cat runner_determinator.py diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 6ee529f38b5c86..f0a1bb003de76d 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -11,6 +11,16 @@ on: required: true type: string description: What CUDA version to build with, "cpu" for none. + use-xpu: + required: false + type: boolean + default: false + description: If set, build with XPU support. + vc-year: + required: false + type: string + default: "2019" + description: The Visual Studio year to use for building. build-with-debug: required: false type: boolean @@ -141,7 +151,7 @@ jobs: SCCACHE_REGION: us-east-1 VC_PRODUCT: "BuildTools" VC_VERSION: "" - VC_YEAR: "2019" + VC_YEAR: "${{ inputs.vc-year }}" ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} @@ -149,6 +159,7 @@ jobs: DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} TORCH_CUDA_ARCH_LIST: "8.6" USE_CUDA: ${{ inputs.cuda-version != 'cpu' && '1' || '0' }} + USE_XPU: ${{ inputs.use-xpu == true && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} run: | .ci/pytorch/win-build.sh diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index 170b35107ac240..5146e7593a5fdd 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -29,9 +29,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-docker-cuda: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -66,7 +76,8 @@ jobs: .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: rocm_version: ["6.1", "6.2"] @@ -101,7 +112,8 @@ jobs: .ci/docker/libtorch/build.sh libtorch-cxx11-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index a85d0c59335cd1..750ee99d52e38d 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -12,11 +12,13 @@ on: - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ paths: - '.ci/docker/manywheel/*' + - '.ci/docker/manywheel/build_scripts/*' - '.ci/docker/common/*' - .github/workflows/build-manywheel-images.yml pull_request: paths: - '.ci/docker/manywheel/*' + - '.ci/docker/manywheel/build_scripts/*' - '.ci/docker/common/*' - .github/workflows/build-manywheel-images.yml @@ -31,9 +33,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-docker-cuda: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -71,7 +83,8 @@ jobs: # NOTE: manylinux_2_28 are still experimental, see https://github.com/pytorch/pytorch/issues/123649 build-docker-cuda-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4", "12.1", "11.8"] @@ -108,7 +121,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}} build-docker-cuda-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" strategy: matrix: cuda_version: ["12.4"] @@ -141,7 +155,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" strategy: matrix: rocm_version: ["6.1", "6.2"] @@ -176,7 +191,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: am2.linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}am2.linux.9xlarge.ephemeral" steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main @@ -205,7 +221,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux-builder:cpu build-docker-cpu-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-manylinux_2_28 steps: @@ -236,7 +253,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28-builder:cpu build-docker-cpu-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-aarch64 steps: @@ -267,7 +285,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64 build-docker-cpu-aarch64-2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.arm64.2xlarge + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.2xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-aarch64-2_28 steps: @@ -301,7 +320,8 @@ jobs: .ci/docker/manywheel/build.sh manylinux2_28_aarch64-builder:cpu-aarch64 build-docker-cpu-cxx11-abi: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: cpu-cxx11-abi steps: @@ -332,7 +352,8 @@ jobs: .ci/docker/manywheel/build.sh manylinuxcxx11-abi-builder:cpu-cxx11-abi build-docker-xpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} - runs-on: linux.9xlarge.ephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" env: GPU_ARCH_TYPE: xpu steps: diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 01f06cdd286cf4..8fe307c3b4c86e 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -13,7 +13,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt pull_request: paths: @@ -21,7 +20,6 @@ on: - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton.txt - - .ci/docker/ci_commit_pins/triton-rocm.txt - .ci/docker/ci_commit_pins/triton-xpu.txt concurrency: @@ -29,9 +27,19 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + build-wheel: name: "Build Triton Wheel" - runs-on: [self-hosted, linux.2xlarge] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" strategy: fail-fast: false matrix: @@ -120,7 +128,7 @@ jobs: fi docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }} if-no-files-found: error @@ -157,7 +165,7 @@ jobs: aws-region: us-east-1 - name: Download Build Artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: # Download all available artifacts path: ${{ runner.temp }}/artifacts-all @@ -201,7 +209,8 @@ jobs: build-conda: name: "Build Triton Conda" - runs-on: [self-hosted, linux.2xlarge] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" strategy: fail-fast: false matrix: @@ -253,7 +262,7 @@ jobs: docker exec -t "${container_name}" python /pytorch/.github/scripts/build_triton_wheel.py --build-conda --py-version="${PY_VERS}" $RELEASE docker exec -t "${container_name}" chown -R 1000.1000 /artifacts - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-conda-${{ matrix.py_vers }} if-no-files-found: error @@ -273,7 +282,7 @@ jobs: - uses: actions/checkout@v3 - name: Download Build Artifacts - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4.1.7 with: # Download all available artifacts path: ${{ runner.temp }}/artifacts-all diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml index d638d588504f2e..8ad611bd7cc917 100644 --- a/.github/workflows/check-labels.yml +++ b/.github/workflows/check-labels.yml @@ -19,6 +19,10 @@ on: branches: [gh/**/base] workflow_dispatch: + inputs: + pr_number: + description: 'PR number to check labels for' + required: true concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} @@ -54,7 +58,7 @@ jobs: - name: Check labels env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PR_NUM: ${{ github.event.number }} + PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }} run: | set -ex - python3 .github/scripts/check_labels.py "${PR_NUM}" + python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}" diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index a9836734bfe5f9..2c83b8cb571961 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -16,6 +16,15 @@ on: paths: [.github/workflows/create_release.yml] jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + release: if: ${{ github.repository == 'pytorch/pytorch' }} name: Create Release @@ -63,7 +72,7 @@ jobs: files: ${{env.PT_RELEASE_FILE}} - name: Upload source distribution to GHA artifacts for release tags if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4.4.0 with: name: ${{ env.PT_RELEASE_FILE }} path: ${{ env.PT_RELEASE_FILE }} @@ -73,14 +82,16 @@ jobs: upload_source_code_to_s3: if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} - runs-on: linux.2xlarge + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" environment: sourcecode-upload name: Upload source code to S3 for release tags permissions: id-token: write - needs: release + needs: + - get-label-type + - release steps: - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4.1.7 with: name: ${{ needs.release.outputs.pt_release_name }} - name: Configure AWS credentials(PyTorch account) diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 5967032e876b47..2e7f041a23e3d8 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -30,8 +30,18 @@ env: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + docker-build: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} + needs: get-label-type timeout-minutes: 240 strategy: fail-fast: false @@ -45,15 +55,15 @@ jobs: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9, - pytorch-linux-focal-py3.8-clang10, + pytorch-linux-focal-py3.9-clang10, pytorch-linux-focal-py3.11-clang10, pytorch-linux-focal-py3.12-clang10, pytorch-linux-focal-rocm-n-1-py3, pytorch-linux-focal-rocm-n-py3, - pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12, + pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12, pytorch-linux-focal-py3-clang9-android-ndk-r21e, - pytorch-linux-jammy-py3.8-gcc11, - pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks, + pytorch-linux-jammy-py3.9-gcc11, + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-xpu-2024.0-py3, pytorch-linux-jammy-py3-clang15-asan, @@ -68,7 +78,7 @@ jobs: - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 - runs-on: [self-hosted, "${{ matrix.runner }}"] + runs-on: "${{ needs.get-label-type.outputs.label-type }}${{ matrix.runner }}" env: DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/${{ matrix.docker-image-name }} steps: diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index e9c8db722432ea..41c5b40860303a 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -34,9 +34,19 @@ env: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + generate-matrix: if: github.repository_owner == 'pytorch' - runs-on: [self-hosted, linux.large] + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.large" outputs: matrix: ${{ steps.generate-matrix.outputs.matrix }} steps: @@ -54,10 +64,12 @@ jobs: build: if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: [self-hosted, linux.2xlarge] + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" environment: ${{ (github.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} timeout-minutes: 240 - needs: generate-matrix + needs: + - generate-matrix + - get-label-type strategy: matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }} fail-fast: false diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index f0f54b38337668..aeb85c90feb78f 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -58,9 +58,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -82,10 +82,10 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -105,6 +105,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-aarch64 secrets: @@ -127,9 +128,9 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -152,6 +153,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda-aarch64 secrets: @@ -173,9 +175,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -197,10 +199,10 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -220,6 +222,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-aarch64 secrets: @@ -242,9 +245,9 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -267,6 +270,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda-aarch64 secrets: @@ -288,9 +292,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -312,10 +316,10 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -335,6 +339,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-aarch64 secrets: @@ -357,9 +362,9 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -382,6 +387,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda-aarch64 secrets: @@ -403,9 +409,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -427,10 +433,10 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - runner_prefix: amz2023. runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: @@ -450,6 +456,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cpu-aarch64-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-aarch64 secrets: @@ -472,9 +479,9 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: amz2023. - runs_on: linux.arm64.m7g.4xlarge + runs_on: linux.arm64.m7g.4xlarge.ephemeral ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64 build_environment: linux-aarch64-binary-manywheel @@ -497,6 +504,7 @@ jobs: GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: pytorch/manylinuxaarch64-builder:cuda12.4-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda-aarch64 secrets: diff --git a/.github/workflows/generated-linux-binary-conda-nightly.yml b/.github/workflows/generated-linux-binary-conda-nightly.yml index b8a90c520abfa1..e4451fb1f9b74b 100644 --- a/.github/workflows/generated-linux-binary-conda-nightly.yml +++ b/.github/workflows/generated-linux-binary-conda-nightly.yml @@ -59,7 +59,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_9-cpu build_environment: linux-binary-conda secrets: @@ -82,7 +82,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -124,7 +124,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda11_8 build_environment: linux-binary-conda @@ -149,7 +149,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -192,7 +192,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda12_1 build_environment: linux-binary-conda @@ -217,7 +217,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -260,7 +260,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_9-cuda12_4 build_environment: linux-binary-conda @@ -285,7 +285,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: conda-py3_9-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -327,7 +327,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_10-cpu build_environment: linux-binary-conda secrets: @@ -350,7 +350,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -392,7 +392,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda11_8 build_environment: linux-binary-conda @@ -417,7 +417,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -460,7 +460,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda12_1 build_environment: linux-binary-conda @@ -485,7 +485,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -528,7 +528,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_10-cuda12_4 build_environment: linux-binary-conda @@ -553,7 +553,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: conda-py3_10-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -595,7 +595,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_11-cpu build_environment: linux-binary-conda secrets: @@ -618,7 +618,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -660,7 +660,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda11_8 build_environment: linux-binary-conda @@ -685,7 +685,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -728,7 +728,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda12_1 build_environment: linux-binary-conda @@ -753,7 +753,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -796,7 +796,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_11-cuda12_4 build_environment: linux-binary-conda @@ -821,7 +821,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: conda-py3_11-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -863,7 +863,7 @@ jobs: GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/conda-builder:cpu-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: conda-py3_12-cpu build_environment: linux-binary-conda secrets: @@ -886,7 +886,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cpu build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -928,7 +928,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda11_8 build_environment: linux-binary-conda @@ -953,7 +953,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda11_8 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -996,7 +996,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda12_1 build_environment: linux-binary-conda @@ -1021,7 +1021,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_1 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1064,7 +1064,7 @@ jobs: GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral build_name: conda-py3_12-cuda12_4 build_environment: linux-binary-conda @@ -1089,7 +1089,7 @@ jobs: DESIRED_PYTHON: "3.12" build_name: conda-py3_12-cuda12_4 build_environment: linux-binary-conda - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml index 748d20503d4eac..ad1098bf7d1702 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml @@ -55,7 +55,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -79,7 +79,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 552c1e042b13cd..408106d0096aba 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -60,7 +60,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -84,7 +84,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cpu-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -128,7 +128,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -153,7 +153,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda11_8-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -198,7 +198,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -223,7 +223,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -268,7 +268,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -293,7 +293,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -338,7 +338,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -366,7 +366,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_1-shared-with-deps-cxx11-abi @@ -448,7 +448,7 @@ jobs: DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.2-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: cxx11-abi - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_2-shared-with-deps-cxx11-abi build_environment: linux-binary-libtorch-cxx11-abi secrets: @@ -476,7 +476,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_2-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml index 61847b43eaadfc..06c26961e9894b 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml @@ -55,7 +55,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -79,7 +79,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index e9f99382773c87..ee9f94c8ac6c25 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -60,7 +60,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -84,7 +84,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -128,7 +128,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -153,7 +153,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda11_8-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -198,7 +198,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -223,7 +223,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -268,7 +268,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -293,7 +293,7 @@ jobs: DESIRED_DEVTOOLSET: pre-cxx11 build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -338,7 +338,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -366,7 +366,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_1-shared-with-deps-pre-cxx11 @@ -448,7 +448,7 @@ jobs: DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main LIBTORCH_VARIANT: shared-with-deps DESIRED_DEVTOOLSET: pre-cxx11 - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: libtorch-rocm6_2-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: @@ -476,7 +476,7 @@ jobs: steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-rocm6_2-shared-with-deps-pre-cxx11 diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 15ae6d41909cf6..d87b832bf03cb0 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -54,8 +54,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -77,57 +78,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -146,8 +101,9 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -169,57 +125,11 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -238,8 +148,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -261,57 +172,11 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index c769a302d6a604..5a86872b3e288a 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -58,8 +58,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel secrets: @@ -79,10 +80,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -101,6 +103,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu secrets: @@ -123,8 +126,9 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -145,10 +149,11 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -168,6 +173,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-cxx11-abi secrets: @@ -190,8 +196,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -213,10 +220,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -236,6 +244,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_8 secrets: @@ -244,77 +253,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -329,8 +267,9 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -352,10 +291,11 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -375,6 +315,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_1 secrets: @@ -383,77 +324,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -468,8 +338,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -491,10 +362,11 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -514,6 +386,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda12_4 secrets: @@ -522,77 +395,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_9-rocm6_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -607,8 +409,9 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -631,11 +434,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-rocm6_1 @@ -692,6 +496,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_1 secrets: @@ -714,8 +519,9 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -738,11 +544,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-rocm6_2 @@ -799,6 +606,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-rocm6_2 secrets: @@ -820,8 +628,9 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-xpu build_environment: linux-binary-manywheel secrets: @@ -843,6 +652,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" permissions: id-token: write @@ -859,7 +669,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_9-xpu @@ -912,6 +722,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-xpu secrets: @@ -933,8 +744,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel secrets: @@ -954,10 +766,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -976,6 +789,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu secrets: @@ -998,8 +812,9 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -1020,10 +835,11 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1043,6 +859,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-cxx11-abi secrets: @@ -1065,8 +882,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1088,10 +906,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1111,6 +930,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_8 secrets: @@ -1119,7 +939,7 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda11_8-split-build: + manywheel-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1129,22 +949,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda11_8-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-split-test: # Testing + manywheel-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda11_8-split-build + - manywheel-py3_10-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1153,44 +973,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8-split + build_name: manywheel-py3_10-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-split-upload: # Uploading + manywheel-py3_10-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda11_8-split-test + needs: manywheel-py3_10-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8-split + build_name: manywheel-py3_10-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_1-build: + manywheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1200,230 +1020,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda12_1 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-test: # Testing + manywheel-py3_10-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_1-full-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-full-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-full-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-full-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: False - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1-full - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda12_4 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_4-build + - manywheel-py3_10-cuda12_4-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1436,10 +1048,11 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1459,6 +1072,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda12_4 secrets: @@ -1467,77 +1081,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_10-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm6_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -1552,8 +1095,9 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -1576,11 +1120,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-rocm6_1 @@ -1637,6 +1182,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_1 secrets: @@ -1659,8 +1205,9 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -1683,11 +1230,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-rocm6_2 @@ -1744,6 +1292,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-rocm6_2 secrets: @@ -1765,8 +1314,9 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-xpu build_environment: linux-binary-manywheel secrets: @@ -1788,6 +1338,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -1804,7 +1355,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_10-xpu @@ -1857,6 +1408,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-xpu secrets: @@ -1878,8 +1430,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel secrets: @@ -1899,10 +1452,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1921,6 +1475,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu secrets: @@ -1943,8 +1498,9 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -1965,10 +1521,11 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1988,6 +1545,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-cxx11-abi secrets: @@ -2010,8 +1568,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2033,10 +1592,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2056,154 +1616,16 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_11-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 + build_name: manywheel-py3_11-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_1-split-build: + manywheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2217,18 +1639,18 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_11-cuda12_1-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-split-test: # Testing + manywheel-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_1-split-build + - manywheel-py3_11-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2241,20 +1663,20 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-split + build_name: manywheel-py3_11-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-split-upload: # Uploading + manywheel-py3_11-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_1-split-test + needs: manywheel-py3_11-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder @@ -2265,16 +1687,16 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-split + build_name: manywheel-py3_11-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_4-build: + manywheel-py3_11-cuda12_1-full-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2284,21 +1706,21 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_11-cuda12_4 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1-full build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-test: # Testing + manywheel-py3_11-cuda12_1-full-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_4-build + - manywheel-py3_11-cuda12_1-full-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2307,42 +1729,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4 + build_name: manywheel-py3_11-cuda12_1-full build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-upload: # Uploading + manywheel-py3_11-cuda12_1-full-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_4-test + needs: manywheel-py3_11-cuda12_1-full-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4 + build_name: manywheel-py3_11-cuda12_1-full secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_4-split-build: + manywheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2356,18 +1780,18 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_11-cuda12_4-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-split-test: # Testing + manywheel-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_4-split-build + - manywheel-py3_11-cuda12_4-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2380,20 +1804,20 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4-split + build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-split-upload: # Uploading + manywheel-py3_11-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_4-split-test + needs: manywheel-py3_11-cuda12_4-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder @@ -2404,9 +1828,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4-split + build_name: manywheel-py3_11-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} @@ -2427,8 +1851,9 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -2451,11 +1876,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-rocm6_1 @@ -2512,6 +1938,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_1 secrets: @@ -2534,8 +1961,9 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -2558,11 +1986,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-rocm6_2 @@ -2619,6 +2048,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-rocm6_2 secrets: @@ -2640,8 +2070,9 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-xpu build_environment: linux-binary-manywheel secrets: @@ -2663,6 +2094,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -2679,7 +2111,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_11-xpu @@ -2732,6 +2164,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-xpu secrets: @@ -2753,8 +2186,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel secrets: @@ -2774,10 +2208,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2796,6 +2231,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu secrets: @@ -2818,8 +2254,9 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -2840,10 +2277,11 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2863,6 +2301,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-cxx11-abi secrets: @@ -2885,8 +2324,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2908,10 +2348,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2928,228 +2369,19 @@ jobs: # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_12-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_12-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1-split + build_name: manywheel-py3_12-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_4-build: + manywheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3159,21 +2391,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_12-cuda12_4 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-test: # Testing + manywheel-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_4-build + - manywheel-py3_12-cuda12_1-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3182,42 +2415,44 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 + build_name: manywheel-py3_12-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-upload: # Uploading + manywheel-py3_12-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_4-test + needs: manywheel-py3_12-cuda12_1-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 + build_name: manywheel-py3_12-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_4-split-build: + manywheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3231,18 +2466,18 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_12-cuda12_4-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-split-test: # Testing + manywheel-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_4-split-build + - manywheel-py3_12-cuda12_4-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3255,20 +2490,20 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4-split + build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-split-upload: # Uploading + manywheel-py3_12-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_4-split-test + needs: manywheel-py3_12-cuda12_4-test with: PYTORCH_ROOT: /pytorch BUILDER_ROOT: /builder @@ -3279,9 +2514,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True + use_split_build: False DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4-split + build_name: manywheel-py3_12-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} @@ -3302,8 +2537,9 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_1 build_environment: linux-binary-manywheel secrets: @@ -3326,11 +2562,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-rocm6_1 @@ -3387,6 +2624,7 @@ jobs: GPU_ARCH_VERSION: 6.1 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_1 secrets: @@ -3409,8 +2647,9 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-rocm6_2 build_environment: linux-binary-manywheel secrets: @@ -3433,11 +2672,12 @@ jobs: GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" steps: - name: Setup ROCm uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-rocm6_2 @@ -3494,6 +2734,7 @@ jobs: GPU_ARCH_VERSION: 6.2 GPU_ARCH_TYPE: rocm DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.2-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-rocm6_2 secrets: @@ -3515,8 +2756,9 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-xpu build_environment: linux-binary-manywheel secrets: @@ -3538,6 +2780,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -3554,7 +2797,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_12-xpu @@ -3607,6 +2850,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-xpu secrets: @@ -3628,8 +2872,9 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel secrets: @@ -3649,10 +2894,11 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3671,6 +2917,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu secrets: @@ -3693,8 +2940,9 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cpu-cxx11-abi build_environment: linux-binary-manywheel secrets: @@ -3715,10 +2963,11 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-cxx11-abi build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3738,6 +2987,7 @@ jobs: GPU_ARCH_TYPE: cpu-cxx11-abi DOCKER_IMAGE: pytorch/manylinuxcxx11-abi-builder:cpu-cxx11-abi-main DESIRED_DEVTOOLSET: cxx11-abi + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-cxx11-abi secrets: @@ -3760,8 +3010,9 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda11_8 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3783,10 +3034,11 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3806,6 +3058,7 @@ jobs: GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda11_8 secrets: @@ -3814,77 +3067,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda11_8-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_13-cuda11_8-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda11_8-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda11_8-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -3899,8 +3081,9 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_1 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -3922,10 +3105,11 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -3945,6 +3129,7 @@ jobs: GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_1 secrets: @@ -3953,77 +3138,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_1-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_13-cuda12_1-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_1-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_1-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -4038,8 +3152,9 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -4061,10 +3176,11 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -4084,6 +3200,7 @@ jobs: GPU_ARCH_VERSION: 12.4 GPU_ARCH_TYPE: cuda DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cuda12_4 secrets: @@ -4092,77 +3209,6 @@ jobs: conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_4-split-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build_name: manywheel-py3_13-cuda12_4-split - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-split-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_4-split-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4-split - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-split-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_4-split-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml @@ -4176,8 +3222,9 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-xpu build_environment: linux-binary-manywheel secrets: @@ -4199,6 +3246,7 @@ jobs: GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -4215,7 +3263,7 @@ jobs: - name: Login to Amazon ECR id: login-ecr uses: aws-actions/amazon-ecr-login@v2 - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: manywheel-py3_13-xpu @@ -4268,6 +3316,7 @@ jobs: DESIRED_CUDA: xpu GPU_ARCH_TYPE: xpu DOCKER_IMAGE: pytorch/manylinux2_28-builder:xpu-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-xpu secrets: diff --git a/.github/workflows/generated-linux-binary-manywheel-split-main.yml b/.github/workflows/generated-linux-binary-manywheel-split-main.yml new file mode 100644 index 00000000000000..9c2456e4632c75 --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-split-main.yml @@ -0,0 +1,182 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel-split + + +on: + push: + branches: + - main + tags: + - 'ciflow/periodic/*' + workflow_dispatch: + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel-split + BUILDER_ROOT: /builder + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + + manywheel-py3_9-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml new file mode 100644 index 00000000000000..c3e0dbdd07c198 --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml @@ -0,0 +1,1516 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel-split + + +on: + push: + # NOTE: Meta Employees can trigger new nightlies using: https://fburl.com/trigger_pytorch_nightly_build + branches: + - nightly + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + - 'ciflow/binaries/*' + - 'ciflow/binaries_wheel/*' + workflow_dispatch: + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + ANACONDA_USER: pytorch + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel-split + BUILDER_ROOT: /builder + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_9-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_9-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.9" + build_name: manywheel-py3_9-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_10-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_10-cpu + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_10-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cpu + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_10-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_10-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.10" + build_name: manywheel-py3_10-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_1-full-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_1-full-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_1-full-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_1-full-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_1-full + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_11-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_11-cpu + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_11-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cpu + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_11-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_11-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.11" + build_name: manywheel-py3_11-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_12-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_12-cpu + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_12-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cpu + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_12-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_12-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.12" + build_name: manywheel-py3_12-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda11_8 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda11_8-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda11_8-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda11_8 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda11_8-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda11_8-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda11_8 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda12_1-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda12_1 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_1-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda12_1-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_1 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_1-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda12_1-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_1 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cuda12_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cuda12_4 + build_environment: linux-binary-manywheel-split + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cuda12_4-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_4 + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge.nvidia.gpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cuda12_4-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + + manywheel-py3_13-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_13-cpu + build_environment: linux-binary-manywheel-split + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_13-cpu-build + - get-label-type + uses: ./.github/workflows/_binary-test-linux.yml + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cpu + build_environment: linux-binary-manywheel-split + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runs_on: linux.4xlarge + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_13-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: manywheel-py3_13-cpu-test + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + use_split_build: True + DESIRED_PYTHON: "3.13" + build_name: manywheel-py3_13-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 574a2f87c014cc..22468055434e86 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -58,6 +58,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -81,6 +82,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -103,6 +105,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cpu-s390x secrets: @@ -124,6 +127,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -147,6 +151,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -169,6 +174,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cpu-s390x secrets: @@ -190,6 +196,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -213,6 +220,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -235,6 +243,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cpu-s390x secrets: @@ -256,6 +265,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -279,6 +289,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -301,6 +312,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.12" build_name: manywheel-py3_12-cpu-s390x secrets: @@ -322,6 +334,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" runs_on: linux.s390x ALPINE_IMAGE: "docker.io/s390x/alpine" @@ -345,6 +358,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel @@ -367,6 +381,7 @@ jobs: DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu-s390x DOCKER_IMAGE: pytorch/manylinuxs390x-builder:cpu-s390x-main + use_split_build: False DESIRED_PYTHON: "3.13" build_name: manywheel-py3_13-cpu-s390x secrets: diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index 0eabbcde433606..39579b08403dae 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -117,7 +117,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cpu @@ -232,7 +232,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cpu @@ -347,7 +347,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cpu @@ -462,7 +462,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cpu diff --git a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml index a428191d4a85e9..4343ac3e75ac33 100644 --- a/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-libtorch-cxx11-abi-nightly.yml @@ -49,7 +49,7 @@ jobs: DESIRED_DEVTOOLSET: cxx11-abi # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -121,7 +121,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-cxx11-abi diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 2900a33d5c6f0a..0a3716c7019b27 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -118,7 +118,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cpu @@ -234,7 +234,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_10-cpu @@ -350,7 +350,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_11-cpu @@ -466,7 +466,7 @@ jobs: # shellcheck disable=SC1091 source "${RUNNER_TEMP}/anaconda/bin/activate" "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cpu diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index da0dd504c75e5f..bcadb5d0fc4507 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} conda-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -123,7 +132,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cpu @@ -145,8 +154,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_9-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -202,7 +213,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cpu @@ -276,7 +287,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -366,7 +378,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda11_8 @@ -388,8 +400,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -446,7 +460,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda11_8 @@ -521,7 +535,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -611,7 +626,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda12_1 @@ -633,8 +648,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -691,7 +708,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda12_1 @@ -766,7 +783,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -856,7 +874,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_9-cuda12_4 @@ -878,8 +896,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_9-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_9-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_9-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -936,7 +956,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_9-cuda12_4 @@ -1011,7 +1031,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1100,7 +1121,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cpu @@ -1122,8 +1143,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_10-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1179,7 +1202,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cpu @@ -1253,7 +1276,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1343,7 +1367,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda11_8 @@ -1365,8 +1389,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1423,7 +1449,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda11_8 @@ -1498,7 +1524,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1588,7 +1615,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda12_1 @@ -1610,8 +1637,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1668,7 +1697,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda12_1 @@ -1743,7 +1772,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1833,7 +1863,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_10-cuda12_4 @@ -1855,8 +1885,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_10-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_10-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_10-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1913,7 +1945,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_10-cuda12_4 @@ -1988,7 +2020,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2077,7 +2110,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cpu @@ -2099,8 +2132,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_11-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2156,7 +2191,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cpu @@ -2230,7 +2265,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2320,7 +2356,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda11_8 @@ -2342,8 +2378,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2400,7 +2438,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda11_8 @@ -2475,7 +2513,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2565,7 +2604,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda12_1 @@ -2587,8 +2626,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2645,7 +2686,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda12_1 @@ -2720,7 +2761,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2810,7 +2852,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_11-cuda12_4 @@ -2832,8 +2874,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_11-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_11-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_11-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2890,7 +2934,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_11-cuda12_4 @@ -2965,7 +3009,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3054,7 +3099,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cpu @@ -3076,8 +3121,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - conda-py3_12-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3133,7 +3180,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cpu @@ -3207,7 +3254,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3297,7 +3345,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda11_8 @@ -3319,8 +3367,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3377,7 +3427,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda11_8 @@ -3452,7 +3502,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3542,7 +3593,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda12_1 @@ -3564,8 +3615,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3622,7 +3675,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda12_1 @@ -3697,7 +3750,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml conda-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3787,7 +3841,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: conda-py3_12-cuda12_4 @@ -3809,8 +3863,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 conda-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: conda-py3_12-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - conda-py3_12-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3867,7 +3923,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: conda-py3_12-cuda12_4 diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 8ac413be0d65ea..85e2564d612f45 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -25,9 +25,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -42,7 +51,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -120,7 +129,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-debug @@ -142,8 +151,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-debug-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -158,7 +169,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -203,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-debug diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 60ba59556926f2..215dbe681896e2 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -49,7 +58,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -127,7 +136,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-debug @@ -149,8 +158,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-debug-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -165,7 +176,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -210,7 +221,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-debug @@ -279,7 +290,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cpu-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -288,7 +299,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda11_8-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -304,7 +316,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -382,7 +394,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda11_8-shared-with-deps-debug @@ -404,8 +416,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda11_8-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda11_8-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda11_8-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -421,7 +435,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -466,7 +480,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda11_8-shared-with-deps-debug @@ -536,7 +550,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda11_8-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -545,7 +559,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_1-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -561,7 +576,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -639,7 +654,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_1-shared-with-deps-debug @@ -661,8 +676,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_1-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_1-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_1-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -678,7 +695,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -723,7 +740,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_1-shared-with-deps-debug @@ -793,7 +810,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_1-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -802,7 +819,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_4-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -818,7 +836,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -896,7 +914,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_4-shared-with-deps-debug @@ -918,8 +936,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_4-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_4-shared-with-deps-debug-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_4-shared-with-deps-debug-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -935,7 +955,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -980,7 +1000,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_4-shared-with-deps-debug @@ -1050,7 +1070,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_4-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index ab00cdc8919ea9..7fd315028bb702 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -25,9 +25,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -42,7 +51,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -120,7 +129,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-release @@ -142,8 +151,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-release-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -158,7 +169,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -203,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-release diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index 842de97a1fbe99..c3ce65daff7097 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} libtorch-cpu-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -49,7 +58,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -127,7 +136,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cpu-shared-with-deps-release @@ -149,8 +158,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cpu-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-release-build - runs-on: windows.4xlarge.nonephemeral + needs: + - libtorch-cpu-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -165,7 +176,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -210,7 +221,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cpu-shared-with-deps-release @@ -279,7 +290,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cpu-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -288,7 +299,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda11_8-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -304,7 +316,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -382,7 +394,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda11_8-shared-with-deps-release @@ -404,8 +416,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda11_8-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda11_8-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda11_8-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -421,7 +435,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -466,7 +480,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda11_8-shared-with-deps-release @@ -536,7 +550,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda11_8-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -545,7 +559,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_1-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -561,7 +576,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -639,7 +654,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_1-shared-with-deps-release @@ -661,8 +676,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_1-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_1-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_1-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -678,7 +695,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -723,7 +740,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_1-shared-with-deps-release @@ -793,7 +810,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_1-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -802,7 +819,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml libtorch-cuda12_4-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -818,7 +836,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -896,7 +914,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: libtorch-cuda12_4-shared-with-deps-release @@ -918,8 +936,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 libtorch-cuda12_4-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cuda12_4-shared-with-deps-release-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - libtorch-cuda12_4-shared-with-deps-release-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -935,7 +955,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -980,7 +1000,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: libtorch-cuda12_4-shared-with-deps-release @@ -1050,7 +1070,7 @@ jobs: LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason - DESIRED_PYTHON: "3.8" + DESIRED_PYTHON: "3.9" build_name: libtorch-cuda12_4-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 98bdf0c6776838..316329b46870fa 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -32,9 +32,18 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} wheel-py3_9-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -124,7 +133,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cpu @@ -146,8 +155,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_9-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -203,7 +214,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cpu @@ -277,7 +288,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -368,7 +380,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda11_8 @@ -390,8 +402,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -448,7 +462,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda11_8 @@ -523,7 +537,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -614,7 +629,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda12_1 @@ -636,8 +651,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -694,7 +711,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda12_1 @@ -769,7 +786,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_9-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -860,7 +878,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_9-cuda12_4 @@ -882,8 +900,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_9-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_9-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_9-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -940,7 +960,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_9-cuda12_4 @@ -1013,9 +1033,10 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cpu-build: + wheel-py3_9-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1023,11 +1044,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -1105,10 +1125,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cpu + name: wheel-py3_9-xpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1125,10 +1145,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cpu-test: # Testing + wheel-py3_9-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_9-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1136,10 +1158,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.10" + DESIRED_PYTHON: "3.9" steps: - name: Display EC2 information shell: bash @@ -1184,10 +1206,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cpu + name: wheel-py3_9-xpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -1235,30 +1257,31 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cpu-upload: # Uploading + wheel-py3_9-xpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cpu-test + needs: wheel-py3_9-xpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cpu + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.9" + build_name: wheel-py3_9-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda11_8-build: + wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1266,9 +1289,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -1349,10 +1371,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda11_8 + name: wheel-py3_10-cpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1369,10 +1391,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda11_8-test: # Testing + wheel-py3_10-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1380,9 +1404,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" steps: @@ -1429,10 +1452,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda11_8 + name: wheel-py3_10-cpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -1480,31 +1503,31 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda11_8-upload: # Uploading + wheel-py3_10-cpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda11_8-test + needs: wheel-py3_10-cpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda11_8 + build_name: wheel-py3_10-cpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda12_1-build: + wheel-py3_10-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1512,8 +1535,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1595,10 +1618,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda12_1 + name: wheel-py3_10-cuda11_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1615,10 +1638,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_1-test: # Testing + wheel-py3_10-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1626,8 +1651,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1675,10 +1700,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_1 + name: wheel-py3_10-cuda11_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -1726,31 +1751,32 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_1-upload: # Uploading + wheel-py3_10-cuda11_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_1-test + needs: wheel-py3_10-cuda11_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_1 + build_name: wheel-py3_10-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda12_4-build: + wheel-py3_10-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1758,8 +1784,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1841,10 +1867,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda12_4 + name: wheel-py3_10-cuda12_1 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1861,10 +1887,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_4-test: # Testing + wheel-py3_10-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_10-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -1872,8 +1900,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -1921,10 +1949,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_4 + name: wheel-py3_10-cuda12_1 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -1972,31 +2000,32 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_4-upload: # Uploading + wheel-py3_10-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_4-test + needs: wheel-py3_10-cuda12_1-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_4 + build_name: wheel-py3_10-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cpu-build: + wheel-py3_10-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2004,10 +2033,11 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" + DESIRED_PYTHON: "3.10" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information @@ -2086,10 +2116,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cpu + name: wheel-py3_10-cuda12_4 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2106,10 +2136,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cpu-test: # Testing + wheel-py3_10-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_10-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2117,10 +2149,11 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" + DESIRED_PYTHON: "3.10" steps: - name: Display EC2 information shell: bash @@ -2165,10 +2198,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cpu + name: wheel-py3_10-cuda12_4 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2216,30 +2249,32 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cpu-upload: # Uploading + wheel-py3_10-cuda12_4-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cpu-test + needs: wheel-py3_10-cuda12_4-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cpu + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.10" + build_name: wheel-py3_10-cuda12_4 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda11_8-build: + wheel-py3_10-xpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2247,12 +2282,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + DESIRED_PYTHON: "3.10" steps: - name: Display EC2 information shell: bash @@ -2330,10 +2363,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_10-xpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2350,10 +2383,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-test: # Testing + wheel-py3_10-xpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_10-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2361,11 +2396,10 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.11" + DESIRED_PYTHON: "3.10" steps: - name: Display EC2 information shell: bash @@ -2410,10 +2444,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda11_8 + name: wheel-py3_10-xpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2461,31 +2495,31 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda11_8-upload: # Uploading + wheel-py3_10-xpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda11_8-test + needs: wheel-py3_10-xpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda11_8 + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.10" + build_name: wheel-py3_10-xpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_1-build: + wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2493,9 +2527,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' @@ -2576,10 +2609,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda12_1 + name: wheel-py3_11-cpu retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2596,10 +2629,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_1-test: # Testing + wheel-py3_11-cpu-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_11-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2607,9 +2642,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" steps: @@ -2656,10 +2690,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_1 + name: wheel-py3_11-cpu path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2707,31 +2741,31 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_1-upload: # Uploading + wheel-py3_11-cpu-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_1-test + needs: wheel-py3_11-cpu-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_1 + build_name: wheel-py3_11-cpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_4-build: + wheel-py3_11-cuda11_8-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2739,8 +2773,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2822,10 +2856,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda12_4 + name: wheel-py3_11-cuda11_8 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2842,10 +2876,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_4-test: # Testing + wheel-py3_11-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_11-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_11-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2853,8 +2889,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2902,10 +2938,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_4 + name: wheel-py3_11-cuda11_8 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -2953,31 +2989,32 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_4-upload: # Uploading + wheel-py3_11-cuda11_8-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_4-test + needs: wheel-py3_11-cuda11_8-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_4 + build_name: wheel-py3_11-cuda11_8 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cpu-build: + wheel-py3_11-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -2985,10 +3022,11 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.12" + DESIRED_PYTHON: "3.11" PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information @@ -3067,10 +3105,10 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cpu + name: wheel-py3_11-cuda12_1 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3087,10 +3125,12 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cpu-test: # Testing + wheel-py3_11-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cpu-build - runs-on: windows.4xlarge.nonephemeral + needs: + - wheel-py3_11-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3098,10 +3138,11 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.12" + DESIRED_PYTHON: "3.11" steps: - name: Display EC2 information shell: bash @@ -3146,10 +3187,10 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cpu + name: wheel-py3_11-cuda12_1 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Checkout PyTorch uses: malfet/checkout@silent-checkout @@ -3197,30 +3238,32 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cpu-upload: # Uploading + wheel-py3_11-cuda12_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cpu-test + needs: wheel-py3_11-cuda12_1-test with: PYTORCH_ROOT: ${{ github.workspace }}/pytorch BUILDER_ROOT: ${{ github.workspace }}/builder PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cpu + DESIRED_CUDA: cu121 + GPU_ARCH_VERSION: 12.1 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-cuda12_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda11_8-build: + wheel-py3_11-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3228,8 +3271,748 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_11-cuda12_4 + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-cuda12_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_11-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_11-cuda12_4 + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-cuda12_4-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_11-cuda12_4-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu124 + GPU_ARCH_VERSION: 12.4 + GPU_ARCH_TYPE: cuda + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-cuda12_4 + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_11-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_11-xpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_11-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.11" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_11-xpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_11-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_11-xpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.11" + build_name: wheel-py3_11-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_12-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.12" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_12-cpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_12-cpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_12-cpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.12" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_12-cpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_12-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_12-cpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DESIRED_PYTHON: "3.12" + build_name: wheel-py3_12-cpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml + wheel-py3_12-cuda11_8-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cu118 + GPU_ARCH_VERSION: 11.8 GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3311,7 +4094,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda11_8 @@ -3333,8 +4116,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda11_8-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda11_8-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda11_8-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3391,7 +4176,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda11_8 @@ -3466,7 +4251,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cuda12_1-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3557,7 +4343,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda12_1 @@ -3579,8 +4365,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda12_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda12_1-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda12_1-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3637,7 +4425,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda12_1 @@ -3712,7 +4500,8 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cuda12_4-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: windows.4xlarge.nonephemeral + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3803,7 +4592,7 @@ jobs: shell: bash run: | "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4.4.0 if: always() with: name: wheel-py3_12-cuda12_4 @@ -3825,8 +4614,10 @@ jobs: .github\scripts\kill_active_ssh_sessions.ps1 wheel-py3_12-cuda12_4-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_12-cuda12_4-build - runs-on: windows.8xlarge.nvidia.gpu + needs: + - wheel-py3_12-cuda12_4-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }}/pytorch @@ -3883,7 +4674,7 @@ jobs: echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: name: wheel-py3_12-cuda12_4 @@ -3956,3 +4747,248 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml + wheel-py3_12-xpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.12" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_12-xpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_12-xpu-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - wheel-py3_12-xpu-build + - get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.12" + steps: + - name: Display EC2 information + shell: bash + run: | + set -euo pipefail + function get_ec2_metadata() { + # Pulled from instance metadata endpoint for EC2 + # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + category=$1 + curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + } + echo "ami-id: $(get_ec2_metadata ami-id)" + echo "instance-id: $(get_ec2_metadata instance-id)" + echo "instance-type: $(get_ec2_metadata instance-type)" + echo "system info $(uname -a)" + - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)" + uses: pytorch/test-infra/.github/actions/setup-ssh@main + continue-on-error: true + with: + github-secret: ${{ secrets.GITHUB_TOKEN }} + # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 + - name: Enable long paths on Windows + shell: powershell + run: | + Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1 + # Since it's just a defensive command, the workflow should continue even the command fails. This step can be + # removed once Windows Defender is removed from the AMI + - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch + continue-on-error: true + shell: powershell + run: | + Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore + # Let's both exclude the path and disable Windows Defender completely just to be sure + # that it doesn't interfere + Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}" + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: wheel-py3_12-xpu + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Populate binary env + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Test PyTorch binary + shell: bash + run: | + "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh" + - name: Wait until all sessions have drained + shell: powershell + working-directory: pytorch + if: always() + timeout-minutes: 120 + run: | + .github\scripts\wait_for_ssh_to_drain.ps1 + - name: Kill active ssh sessions if still around (Useful if workflow was cancelled) + shell: powershell + working-directory: pytorch + if: always() + run: | + .github\scripts\kill_active_ssh_sessions.ps1 + wheel-py3_12-xpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_12-xpu-test + with: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: xpu + GPU_ARCH_TYPE: xpu + DESIRED_PYTHON: "3.12" + build_name: wheel-py3_12-xpu + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/inductor-cu124.yml b/.github/workflows/inductor-cu124.yml index 39d1204a4e7033..950afbf0b591e8 100644 --- a/.github/workflows/inductor-cu124.yml +++ b/.github/workflows/inductor-cu124.yml @@ -18,11 +18,22 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_4-py3_10-gcc9-inductor-build: # Should be synced with the one in inductor.yml, but this doesn't run inductor_timm name: cuda12.4-py3.10-gcc9-sm86 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" sync-tag: linux-focal-cuda12_4-py3_10-gcc9-inductor-build build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml new file mode 100644 index 00000000000000..d31dbc5951ea12 --- /dev/null +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -0,0 +1,40 @@ +name: inductor-micro-benchmark-x86 + +on: + schedule: + - cron: 0 7 * * * + push: + tags: + - ciflow/inductor-micro-benchmark-cpu-x86/* + workflow_dispatch: + + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-build.yml + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + # Use metal host for benchmark jobs + test-matrix: | + { include: [ + { config: "inductor-micro-benchmark-cpu-x86", shard: 1, num_shards: 1, runner: "linux.24xl.spr-metal" }, + ]} + + linux-jammy-cpu-py3_9-gcc11-inductor-micro-benchmark-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor + uses: ./.github/workflows/_linux-test.yml + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build + with: + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} + use-gha: anything-non-empty-to-use-gha + timeout-minutes: 720 diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 431545ea6d0dc7..fad0538d107559 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -16,10 +16,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index a5e4ad1781aa06..a38bcadf7e5f7a 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -13,10 +13,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-test-nightly-a10g.yml b/.github/workflows/inductor-perf-test-nightly-a10g.yml index 4acfd43aac3637..e42d7d4148c228 100644 --- a/.github/workflows/inductor-perf-test-nightly-a10g.yml +++ b/.github/workflows/inductor-perf-test-nightly-a10g.yml @@ -68,10 +68,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index 73e76486381935..39bc85752ae67a 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -50,10 +50,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-jammy-aarch64-py3_10-inductor-build: name: linux-jammy-aarch64-py3.10-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.arm64.m7g.4xlarge build-environment: linux-jammy-aarch64-py3.10 docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 6011a133b257a9..83e8b26dd628e2 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -48,12 +48,23 @@ concurrency: permissions: read-all jobs: - linux-jammy-cpu-py3_8-gcc11-inductor-build: - name: linux-jammy-cpu-py3.8-gcc11-inductor + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_x86", shard: 1, num_shards: 3, runner: "linux.24xl.spr-metal" }, @@ -74,32 +85,32 @@ jobs: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test-nightly: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build if: github.event.schedule == '0 7 * * *' with: - build-environment: linux-jammy-py3.8-gcc11-build + build-environment: linux-jammy-py3.9-gcc11-build dashboard-tag: training-false-inference-true-default-true-dynamic-true-aotinductor-true - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} use-gha: anything-non-empty-to-use-gha timeout-minutes: 720 secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-py3.8-gcc11-build + build-environment: linux-jammy-py3.9-gcc11-build dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-aotinductor-${{ inputs.aotinductor }} - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} use-gha: anything-non-empty-to-use-gha timeout-minutes: 720 secrets: diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 2f129c52fe135b..5c7651d516d8b1 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -66,10 +66,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 7f83ce1eb54862..6bcb1be5ef0944 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -18,10 +18,21 @@ concurrency: permissions: read-all jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-periodic-dynamo-benchmarks-build: name: cuda12.1-py3.10-gcc9-sm86-periodic-dynamo-benchmarks uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' @@ -60,7 +71,9 @@ jobs: linux-focal-cuda12_1-py3_10-gcc9-inductor-build-gcp: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 2b1ca7bbd1b3e1..dd26a0a70fe312 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -22,11 +22,22 @@ concurrency: permissions: read-all jobs: - linux-focal-rocm6_1-py3_8-inductor-build: - name: rocm6.1-py3.8-inductor + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-focal-rocm6_1-py3_10-inductor-build: + name: rocm6.1-py3.10-inductor uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: - build-environment: linux-focal-rocm6.1-py3.8 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -34,14 +45,14 @@ jobs: { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" }, ]} - linux-focal-rocm6_1-py3_8-inductor-test: + linux-focal-rocm6_1-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.1-py3.8-inductor + name: rocm6.1-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_1-py3_8-inductor-build + needs: linux-focal-rocm6_1-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-inductor-build.outputs.test-matrix }} diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index c9987f0f130bd3..88ffe090fd8897 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -35,28 +35,28 @@ jobs: build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -80,11 +80,11 @@ jobs: build-environment: linux-focal-cuda12.1-py3.12-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3.12-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_12-gcc9-inductor-test: @@ -103,10 +103,10 @@ jobs: with: build-environment: linux-jammy-py3.12-gcc11 docker-image-name: pytorch-linux-jammy-py3.12-halide - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, + { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} linux-jammy-cpu-py3_12-inductor-halide-test: @@ -128,11 +128,11 @@ jobs: build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.6' - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} @@ -149,68 +149,68 @@ jobs: secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-build: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-build: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11-inductor-benchmarks - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + build-environment: linux-jammy-py3.9-gcc11-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_inductor_huggingface_amp_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_timm_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "cpu_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.16xlarge.spr" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_huggingface_freezing", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_timm_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_torchbench_amp_freezing", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, - { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.24xl.spr-metal" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_huggingface_freezing_avx2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_torchbench_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, - { config: "cpu_inductor_timm_freezing_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.10xlarge.avx2" }, + { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_inductor_amp_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - linux-jammy-cpu-py3_8-gcc11-inductor-test: - name: linux-jammy-cpu-py3.8-gcc11-inductor + linux-jammy-cpu-py3_9-gcc11-inductor-test: + name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cpu-py3_8-gcc11-inductor-build + needs: linux-jammy-cpu-py3_9-gcc11-inductor-build with: - build-environment: linux-jammy-py3.8-gcc11-build - docker-image: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cpu-py3_8-gcc11-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-gcc11-build + docker-image: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c61a7703e7c22c..b0427b87bb16ab 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -28,7 +28,7 @@ jobs: needs: get-label-type with: timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -45,7 +45,7 @@ jobs: needs: get-label-type with: timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed @@ -60,7 +60,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -95,7 +95,7 @@ jobs: pr-sanity-checks: name: pr-sanity-checks needs: get-label-type - runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large"] + runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}linux.large"] # Only run this on pull requests. This check is simple enough to be done without a Docker image if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') steps: @@ -116,7 +116,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: -1 submodules: true @@ -153,7 +153,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -192,7 +192,7 @@ jobs: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" docker-image: pytorch-linux-focal-linter fetch-depth: 0 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -223,7 +223,7 @@ jobs: cache: pip - name: Install dependencies run: | - pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.1.* numpy==1.24.* + pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* fbscribelogger==0.1.* numpy==1.24.* pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/ - name: Run run_test.py (nonretryable) run: | diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 35f7a3ce116881..5057e9da2d1dd6 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -31,9 +31,9 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 docs-push: name: docs push @@ -42,8 +42,8 @@ jobs: - docs-build - get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11 docker-image: ${{ needs.docs-build.outputs.docker-image }} push: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || startsWith(github.event.ref, 'refs/tags/v') }} run-doxygen: true diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 714f1c5d641df8..5fe1784e59f6d0 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -52,14 +52,16 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-test: name: linux-focal-cuda12.1-py3.10-gcc9 @@ -77,19 +79,21 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-test: @@ -104,38 +108,38 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-build.outputs.test-matrix }} - parallelnative-linux-jammy-py3_8-gcc11-build: - name: parallelnative-linux-jammy-py3.8-gcc11 + parallelnative-linux-jammy-py3_9-gcc11-build: + name: parallelnative-linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: parallelnative-linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: parallelnative-linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - parallelnative-linux-jammy-py3_8-gcc11-test: - name: parallelnative-linux-jammy-py3.8-gcc11 + parallelnative-linux-jammy-py3_9-gcc11-test: + name: parallelnative-linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - parallelnative-linux-jammy-py3_8-gcc11-build + - parallelnative-linux-jammy-py3_9-gcc11-build - target-determination with: - build-environment: parallelnative-linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.parallelnative-linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.parallelnative-linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} + build-environment: parallelnative-linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.parallelnative-linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.parallelnative-linux-jammy-py3_9-gcc11-build.outputs.test-matrix }} linux-focal-cuda11_8-py3_9-gcc9-build: name: linux-focal-cuda11.8-py3.9-gcc9 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.9-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 cuda-arch-list: 8.6 @@ -159,7 +163,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.10-gcc9-debug docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 build-with-debug: true @@ -250,7 +254,6 @@ jobs: name: buck-build-test uses: ./.github/workflows/_buck-build-test.yml with: - runner_prefix: "amz2023." test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1, runner: "ubuntu-latest" }, @@ -260,7 +263,6 @@ jobs: name: android-emulator-build-test uses: ./.github/workflows/_run_android_tests.yml with: - runner_prefix: "amz2023." test-matrix: | { include: [ { config: 'default', @@ -278,12 +280,12 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-vulkan-focal-py3.11-clang10 docker-image-name: pytorch-linux-focal-py3.11-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} linux-vulkan-focal-py3_11-clang10-test: @@ -295,13 +297,13 @@ jobs: docker-image: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.test-matrix }} - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-rocm6.1-py3.8 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -310,34 +312,36 @@ jobs: { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build: name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-experimental-split-build-test: @@ -357,7 +361,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda11.8-py3.9-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 32c9a4837e5cd2..a7c17117a8c477 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -43,72 +43,71 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} - linux-jammy-py3_8-gcc11-build: - name: linux-jammy-py3.8-gcc11 + linux-jammy-py3_9-gcc11-build: + name: linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-py3.8-gcc11 - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11 + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit - linux-jammy-py3_8-gcc11-test: - name: linux-jammy-py3.8-gcc11 + linux-jammy-py3_9-gcc11-test: + name: linux-jammy-py3.9-gcc11 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_8-gcc11-build + - linux-jammy-py3_9-gcc11-build - target-determination with: - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.test-matrix }} secrets: inherit linux-docs: name: linux-docs uses: ./.github/workflows/_docs.yml - needs: linux-jammy-py3_8-gcc11-build + needs: linux-jammy-py3_9-gcc11-build with: - runner_prefix: amz2023. - build-environment: linux-jammy-py3.8-gcc11 - docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }} + build-environment: linux-jammy-py3.9-gcc11 + docker-image: ${{ needs.linux-jammy-py3_9-gcc11-build.outputs.docker-image }} secrets: inherit - linux-jammy-py3_8-gcc11-no-ops: - name: linux-jammy-py3.8-gcc11-no-ops + linux-jammy-py3_9-gcc11-no-ops: + name: linux-jammy-py3.9-gcc11-no-ops uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-py3.8-gcc11-no-ops - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11-no-ops + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} secrets: inherit - linux-jammy-py3_8-gcc11-pch: - name: linux-jammy-py3.8-gcc11-pch + linux-jammy-py3_9-gcc11-pch: + name: linux-jammy-py3.9-gcc11-pch uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-py3.8-gcc11-pch - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc11-pch + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -120,17 +119,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang15-asan docker-image-name: pytorch-linux-jammy-py3-clang15-asan test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit @@ -149,63 +148,63 @@ jobs: sync-tag: asan-test secrets: inherit - linux-focal-py3_8-clang10-onnx-build: - name: linux-focal-py3.8-clang10-onnx + linux-focal-py3_9-clang10-onnx-build: + name: linux-focal-py3.9-clang10-onnx uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-py3.8-clang10-onnx + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-py3.9-clang10-onnx docker-image-name: pytorch-linux-focal-py3-clang10-onnx test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit - linux-focal-py3_8-clang10-onnx-test: - name: linux-focal-py3.8-clang10-onnx + linux-focal-py3_9-clang10-onnx-test: + name: linux-focal-py3.9-clang10-onnx uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-onnx-build + - linux-focal-py3_9-clang10-onnx-build - target-determination with: - build-environment: linux-focal-py3.8-clang10-onnx - docker-image: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10-onnx + docker-image: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-onnx-build.outputs.test-matrix }} secrets: inherit - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-build: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + build-environment: linux-focal-py3.9-clang10 + docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-test: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-build + - linux-focal-py3_9-clang10-build - target-determination with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10 + docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} secrets: inherit linux-focal-py3_11-clang10-build: @@ -213,20 +212,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.11-clang10 docker-image-name: pytorch-linux-focal-py3.11-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -247,18 +246,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3.12-clang10 docker-image-name: pytorch-linux-focal-py3.12-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -278,14 +277,14 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -307,16 +306,16 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -338,7 +337,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-mobile-build docker-image-name: pytorch-linux-jammy-py3-clang15-asan build-generates-artifacts: false @@ -348,14 +347,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda-11_8-cudnn9-py3_8-clang12-build: - name: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + linux-jammy-cuda-11_8-cudnn9-py3_9-clang12-build: + name: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-cuda11.8-cudnn9-py3.8-clang12 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.8-clang12 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda11.8-cudnn9-py3.9-clang12 + docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn9-py3.9-clang12 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, @@ -367,7 +366,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-py3-clang9-mobile-custom-build-static docker-image-name: pytorch-linux-focal-py3-clang9-android-ndk-r21e build-generates-artifacts: false @@ -377,28 +376,28 @@ jobs: ]} secrets: inherit - linux-focal-py3_8-clang9-xla-build: - name: linux-focal-py3_8-clang9-xla + linux-focal-py3_9-clang9-xla-build: + name: linux-focal-py3_9-clang9-xla uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-py3.8-clang9-xla + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-py3.9-clang9-xla docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.1-lite test-matrix: | { include: [ - { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.12xlarge" }, + { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, ]} secrets: inherit - linux-focal-py3_8-clang9-xla-test: - name: linux-focal-py3_8-clang9-xla + linux-focal-py3_9-clang9-xla-test: + name: linux-focal-py3_9-clang9-xla uses: ./.github/workflows/_linux-test.yml - needs: linux-focal-py3_8-clang9-xla-build + needs: linux-focal-py3_9-clang9-xla-build with: - build-environment: linux-focal-py3.8-clang9-xla - docker-image: ${{ needs.linux-focal-py3_8-clang9-xla-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang9-xla-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang9-xla + docker-image: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang9-xla-build.outputs.test-matrix }} secrets: inherit win-vs2019-cpu-py3-build: @@ -425,13 +424,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: cpu test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} secrets: inherit @@ -440,13 +439,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -455,13 +454,13 @@ jobs: uses: ./.github/workflows/_bazel-build-test.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.large" + runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" build-environment: linux-focal-cuda12.4-py3.10-gcc9-bazel-test docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-version: "12.4" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -489,14 +488,14 @@ jobs: ]} secrets: inherit - linux-jammy-py3_8-gcc11-mobile-lightweight-dispatch-build: - name: linux-jammy-py3.8-gcc11-mobile-lightweight-dispatch-build + linux-jammy-py3_9-gcc11-mobile-lightweight-dispatch-build: + name: linux-jammy-py3.9-gcc11-mobile-lightweight-dispatch-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-jammy-py3.8-gcc111-mobile-lightweight-dispatch-build - docker-image-name: pytorch-linux-jammy-py3.8-gcc11 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.9-gcc111-mobile-lightweight-dispatch-build + docker-image-name: pytorch-linux-jammy-py3.9-gcc11 build-generates-artifacts: false test-matrix: | { include: [ @@ -504,15 +503,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_1-py3_8-build: + linux-focal-rocm6_1-py3_10-build: # don't run build twice on main if: github.event_name == 'pull_request' - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-rocm6.1-py3.8 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -528,17 +527,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -559,12 +558,12 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-executorch docker-image-name: pytorch-linux-jammy-py3-clang12-executorch test-matrix: | { include: [ - { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit @@ -583,18 +582,18 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: True build-environment: linux-focal-py3.12-clang10 docker-image-name: pytorch-linux-focal-py3.12-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 1, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 2, num_shards: 3, runner: "amz2023.linux.2xlarge" }, - { config: "dynamo", shard: 3, num_shards: 3, runner: "amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 3, runner: "linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 3, runner: "linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 1, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 2, num_shards: 3, runner: "linux.2xlarge" }, + { config: "dynamo", shard: 3, num_shards: 3, runner: "linux.2xlarge" }, ]} secrets: inherit @@ -614,7 +613,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm75 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '7.5' diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 7486162e620a38..76b42498333a9d 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -3,18 +3,12 @@ name: rocm on: push: branches: -# - main + - main - release/* tags: - ciflow/rocm/* workflow_dispatch: schedule: - # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. - # Also run less frequently on weekends. - - cron: 45 0,8,16 * * 1-5 - - cron: 45 4 * * 0,6 - - cron: 45 4,12,20 * * 1-5 - - cron: 45 12 * * 0,6 - cron: 29 8 * * * # about 1:29am PDT concurrency: @@ -31,11 +25,11 @@ jobs: id-token: write contents: read - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6.1-py3.8 + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -48,16 +42,16 @@ jobs: { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} diff --git a/.github/workflows/runner-determinator-validator.yml b/.github/workflows/runner-determinator-validator.yml new file mode 100644 index 00000000000000..976fd2cccad5da --- /dev/null +++ b/.github/workflows/runner-determinator-validator.yml @@ -0,0 +1,40 @@ +name: Validate Runner Determinator Script is in Sync + +on: + # Run on PRs when the runner-determinator script is updated to ensure it's copies are kept in sync + pull_request: + paths: + - .github/workflows/_runner-determinator.yml + - .github/workflows/runner-determinator-validator.yml + - .github/scripts/runner_determinator.py + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +jobs: + check-runner-determinator: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Run Hardcode runner-determinator script + id: hardcode-script + run: | + # Extract the script content from _runner-determinator.yml and skip the first 10 spaces of each line + script_content=$(awk '/cat < runner_determinator.py/{flag=1;next}/EOF$/{flag=0}flag{print substr($0, 11)}' .github/workflows/_runner-determinator.yml) + + # Write the extracted script content to runner_determinator.py + echo "$script_content" > runner_determinator_workflow.py + + - name: Compare runner-determinator script embedded in workflow with checked in script + run: | + # Compare the extracted runner_determinator script with the existing one + # If this check fails, then make sure the contents of .github/scripts/runner_determinator.py is in sync with the + # version embedded into .github/workflows/_runner-determinator.yml + diff runner_determinator_workflow.py .github/scripts/runner_determinator.py + # Fail the job if the scripts are not identical + continue-on-error: false \ No newline at end of file diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 5f7179b6303267..426a4473cc034b 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -50,18 +50,20 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3-gcc9-slow-gradcheck docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 6, num_shards: 6, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 6, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 7, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 8, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3-gcc9-slow-gradcheck-test: @@ -81,14 +83,15 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-sm86-test: @@ -102,38 +105,38 @@ jobs: docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-sm86-build.outputs.test-matrix }} - linux-focal-py3_8-clang10-build: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-build: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-py3.8-clang10 - docker-image-name: pytorch-linux-focal-py3.8-clang10 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-py3.9-clang10 + docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, + { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} - linux-focal-py3_8-clang10-test: - name: linux-focal-py3.8-clang10 + linux-focal-py3_9-clang10-test: + name: linux-focal-py3.9-clang10 uses: ./.github/workflows/_linux-test.yml needs: - - linux-focal-py3_8-clang10-build + - linux-focal-py3_9-clang10-build - target-determination with: - build-environment: linux-focal-py3.8-clang10 - docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }} + build-environment: linux-focal-py3.9-clang10 + docker-image: ${{ needs.linux-focal-py3_9-clang10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-rocm6.1-py3.8 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -141,33 +144,33 @@ jobs: { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, ]} - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} linux-jammy-py3_10-clang15-asan-build: name: linux-jammy-py3.10-clang15-asan uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang15-asan docker-image-name: pytorch-linux-jammy-py3-clang15-asan test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, - { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge" }, + { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 56e349dfa1b826..047e4a47ab97aa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -21,7 +21,7 @@ on: jobs: stale: if: ${{ github.repository == 'pytorch/pytorch' }} - runs-on: linux.large.arc + runs-on: linux.large permissions: contents: read pull-requests: write diff --git a/.github/workflows/sync_distributed_folder_prototype.yml b/.github/workflows/sync_distributed_folder_prototype.yml deleted file mode 100644 index 5b7d8e6b1f17f4..00000000000000 --- a/.github/workflows/sync_distributed_folder_prototype.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: Sync Distributed Folder - -on: - #push: - # branches: - # - 'main' - # paths: - # - 'torch/distributed/**' - workflow_dispatch: - pull_request: - paths: - - '.github/scripts/sync_distributed_folder_prototype.sh' - - '.github/workflows/sync_distributed_folder_prototype.yml' - -env: - WITH_PUSH: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} - -permissions: - contents: write - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - sync: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - run: .github/scripts/sync_distributed_folder_prototype.sh diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index e8bf91c8d9ee91..373f464eae1390 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -10,8 +10,18 @@ permissions: contents: read jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + index: - runs-on: linux.g5.4xlarge.nvidia.gpu # 1 GPU A10G 24GB each + needs: get-label-type + runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" # 1 GPU A10G 24GB each environment: target-determinator-env steps: - name: Clone PyTorch diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index ac581496689977..4e2098e589238d 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -11,10 +11,21 @@ concurrency: cancel-in-progress: true jobs: + get-label-type: + name: get-label-type + uses: ./.github/workflows/_runner-determinator.yml + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + linux-focal-cuda12_1-py3_10-gcc9-torchbench-build-gcp: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index b9361e0d2eed16..f5a812abcac551 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -48,17 +48,17 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm86 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-sm86-test: @@ -80,7 +80,7 @@ jobs: build-environment: libtorch-linux-focal-cuda12.1-py3.7-gcc9 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 build-generates-artifacts: false - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.4xlarge" test-matrix: | { include: [ @@ -93,7 +93,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-no-ops docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | @@ -109,7 +109,7 @@ jobs: build-environment: libtorch-linux-focal-cuda12.4-py3.7-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 build-generates-artifacts: false - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "linux.4xlarge" test-matrix: | { include: [ @@ -122,7 +122,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda12.4-py3.10-gcc9-no-ops docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | @@ -138,7 +138,7 @@ jobs: docker-image-name: pytorch-linux-focal-py3-clang9-android-ndk-r21e test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "amz2023.linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, ]} macos-py3-arm64-build: @@ -223,13 +223,13 @@ jobs: cuda-version: "12.1" runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" - linux-focal-rocm6_1-py3_8-build: - name: linux-focal-rocm6.1-py3.8 + linux-focal-rocm6_1-py3_10-build: + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." - build-environment: linux-focal-rocm6.1-py3.8 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-focal-rocm6.1-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -240,19 +240,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_1-py3_8-test: + linux-focal-rocm6_1-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.1-py3.8 + name: linux-focal-rocm6.1-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_1-py3_8-build + - linux-focal-rocm6_1-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.1-py3.8 - docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }} + build-environment: linux-focal-rocm6.1-py3.10 + docker-image: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm6_1-py3_10-build.outputs.test-matrix }} tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build: @@ -260,20 +260,22 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda12.4-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.2xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.4xlarge.nvidia.gpu" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build-test: @@ -292,15 +294,15 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}amz2023." + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true build-environment: linux-focal-cuda11.8-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}amz2023.linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, ]} linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build-test: diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index e82dc61370d3c6..f9e5593bf66ff1 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -2,7 +2,7 @@ name: Upload test stats on: workflow_run: - workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-cu124, inductor-rocm] + workflows: [pull, trunk, periodic, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm] types: - completed @@ -96,7 +96,7 @@ jobs: python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" - name: Upload gpt-fast benchmark results to Rockset - if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && github.event.workflow_run.name == 'inductor-micro-benchmark' + if: steps.upload-s3.outcome && steps.upload-s3.outcome == 'success' && contains(github.event.workflow_run.name, 'inductor-micro-benchmark') env: ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }} WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 81fe225c5fc48a..17fd3e4dfc6b71 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -49,3 +49,12 @@ jobs: build-environment: linux-jammy-xpu-py3.9 docker-image: ${{ needs.linux-jammy-xpu-py3_9-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-xpu-py3_9-build.outputs.test-matrix }} + + windows-xpu-build: + name: win-vs2022-xpu-py3 + uses: ./.github/workflows/_win-build.yml + with: + build-environment: win-vs2022-xpu-py3 + cuda-version: cpu + use-xpu: true + vc-year: '2022' diff --git a/.lintrunner.toml b/.lintrunner.toml index dafe6207a7c9a6..9b43b382809212 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -138,8 +138,8 @@ init_command = [ '--dry-run={{DRYRUN}}', 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', - 'expecttest==0.1.6', - 'mypy==1.10.0', + 'expecttest==0.2.1', + 'mypy==1.11.2', 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', @@ -1482,7 +1482,7 @@ init_command = [ 'black==23.12.1', 'usort==1.0.8.post1', 'isort==5.13.2', - 'ruff==0.6.0', # sync with RUFF + 'ruff==0.6.3', # sync with RUFF ] is_formatter = true @@ -1567,7 +1567,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.6.0', # sync with PYFMT + 'ruff==0.6.3', # sync with PYFMT ] is_formatter = true diff --git a/BUILD.bazel b/BUILD.bazel index f079c76a7f7204..1018f7907adecd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -332,6 +332,7 @@ intern_build_aten_ops( "@fbgemm", "@mkl", "@sleef", + "@mkl_dnn//:mkl-dnn", ], ) @@ -574,7 +575,7 @@ cu_library( name = "torch_cuda", srcs = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], copts = torch_cuda_half_options, @@ -722,7 +723,7 @@ cc_library( "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/CMakeLists.txt b/CMakeLists.txt index 89ef59681bfff4..0318fcb4d1ec04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -496,10 +496,6 @@ if(USE_SYSTEM_LIBS) endif() endif() -# Used when building Caffe2 through setup.py -option(BUILDING_WITH_TORCH_LIBS - "Tell cmake if Caffe2 is being built alongside torch libs" ON) - # /Z7 override option When generating debug symbols, CMake default to use the # flag /Zi. However, it is not compatible with sccache. So we rewrite it off. # But some users don't use sccache; this override is for them. @@ -893,6 +889,16 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA OR USE_ROCM" OFF) +# +# Cannot be put into Dependencies.cmake due circular dependency: +# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake +# +if(USE_ROCM) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) + include(cmake/External/aotriton.cmake) + endif() +endif() + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/CODEOWNERS b/CODEOWNERS index 7b9db26104a9d5..bafce8f6f53521 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -57,7 +57,6 @@ nn/qat/ @jerryzh168 # Docker /.ci/docker/ @jeffdaily /.ci/docker/ci_commit_pins/triton.txt @desertfire @Chillee @eellison @shunting314 @bertmaher @jeffdaily @jataylo @jithunnair-amd @pruthvistony -/.ci/docker/ci_commit_pins/triton-rocm.txt @jeffdaily @jataylo @jithunnair-amd @pruthvistony /.ci/docker/ci_commit_pins/triton-xpu.txt @EikanWang @gujinghui # Github Actions diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 12f8222ff1bd48..99e47ef502998b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,7 +50,6 @@ aspects of contributing to PyTorch. - [Windows development tips](#windows-development-tips) - [Known MSVC (and MSVC with NVCC) bugs](#known-msvc-and-msvc-with-nvcc-bugs) - [Building on legacy code and CUDA](#building-on-legacy-code-and-cuda) -- [Running clang-tidy](#running-clang-tidy) - [Pre-commit tidy/linting hook](#pre-commit-tidylinting-hook) - [Building PyTorch with ASAN](#building-pytorch-with-asan) - [Getting `ccache` to work](#getting-ccache-to-work) @@ -1132,38 +1131,6 @@ CUDA, MSVC, and PyTorch versions are interdependent; please install matching ver Note: There's a [compilation issue](https://github.com/oneapi-src/oneDNN/issues/812) in several Visual Studio 2019 versions since 16.7.1, so please make sure your Visual Studio 2019 version is not in 16.7.1 ~ 16.7.5 -## Running clang-tidy - -[Clang-Tidy](https://clang.llvm.org/extra/clang-tidy/index.html) is a C++ -linter and static analysis tool based on the clang compiler. We run clang-tidy -in our CI to make sure that new C++ code is safe, sane and efficient. See the -[`clang-tidy` job in our GitHub Workflow's -lint.yml file](https://github.com/pytorch/pytorch/blob/main/.github/workflows/lint.yml) -for the simple commands we use for this. - -To run clang-tidy locally, follow these steps: - -1. Install clang-tidy. -We provide custom built binaries which have additional checks enabled. You can install it by running: -```bash -python3 -m tools.linter.clang_tidy.generate_build_files -``` -We currently only support Linux and MacOS (x86). - -2. Install clang-tidy driver script dependencies -```bash -pip3 install -r tools/linter/clang_tidy/requirements.txt -``` - -3. Run clang-tidy -```bash -# Run clang-tidy on the entire codebase -make clang-tidy -# Run clang-tidy only on your changes -make clang-tidy CHANGED_ONLY=--changed-only -``` -This internally invokes our driver script and closely mimics how clang-tidy is run on CI. - ## Pre-commit tidy/linting hook We use clang-tidy to perform additional diff --git a/README.md b/README.md index d2945520cd36f5..c7dd72ccc77c64 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ Our trunk health (Continuous Integration signals) can be found at [hud.pytorch.o - [NVIDIA CUDA Support](#nvidia-cuda-support) - [AMD ROCm Support](#amd-rocm-support) - [Intel GPU Support](#intel-gpu-support) - - [Install Dependencies](#install-dependencies) - [Get the PyTorch Source](#get-the-pytorch-source) + - [Install Dependencies](#install-dependencies) - [Install PyTorch](#install-pytorch) - [Adjust Build Options (Optional)](#adjust-build-options-optional) - [Docker Image](#docker-image) @@ -161,9 +161,34 @@ They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) #### Prerequisites If you are installing from source, you will need: - Python 3.8 or later (for Linux, Python 3.8.1+ is needed) -- A compiler that fully supports C++17, such as clang or gcc (gcc 9.4.0 or newer is required) +- A compiler that fully supports C++17, such as clang or gcc (gcc 9.4.0 or newer is required, on Linux) +- Visual Studio or Visual Studio Build Tool on Windows + +\* PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, +Professional, or Community Editions. You can also install the build tools from +https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* +come with Visual Studio Code by default. + +\* We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. + +An example of environment setup is shown below: -We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. +* Linux: + +```bash +$ source /bin/activate +$ conda create -y -n +$ conda activate +``` + +* Windows: + +```bash +$ source \Scripts\activate.bat +$ conda create -y -n +$ conda activate +$ call "C:\Program Files\Microsoft Visual Studio\\Community\VC\Auxiliary\Build\vcvarsall.bat" x64 +``` ##### NVIDIA CUDA Support If you want to compile with CUDA support, [select a supported version of CUDA from our support matrix](https://pytorch.org/get-started/locally/), then install the following: @@ -194,12 +219,23 @@ If you want to compile with Intel GPU support, follow these If you want to disable Intel GPU support, export the environment variable `USE_XPU=0`. Other potentially useful environment variables may be found in `setup.py`. +#### Get the PyTorch Source +```bash +git clone --recursive https://github.com/pytorch/pytorch +cd pytorch +# if you are updating an existing checkout +git submodule sync +git submodule update --init --recursive +``` + #### Install Dependencies **Common** ```bash conda install cmake ninja +# Run this command on native Windows +conda install rust # Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below pip install -r requirements.txt ``` @@ -235,15 +271,6 @@ pip install mkl-static mkl-include conda install -c conda-forge libuv=1.39 ``` -#### Get the PyTorch Source -```bash -git clone --recursive https://github.com/pytorch/pytorch -cd pytorch -# if you are updating an existing checkout -git submodule sync -git submodule update --init --recursive -``` - #### Install PyTorch **On Linux** @@ -284,13 +311,6 @@ python3 setup.py develop **On Windows** -Choose Correct Visual Studio Version. - -PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise, -Professional, or Community Editions. You can also install the build tools from -https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* -come with Visual Studio Code by default. - If you want to build legacy python code, please refer to [Building on legacy code and CUDA](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#building-on-legacy-code-and-cuda) **CPU-only builds** @@ -298,7 +318,6 @@ If you want to build legacy python code, please refer to [Building on legacy cod In this mode PyTorch computations will run on your CPU, not your GPU ```cmd -conda activate python setup.py develop ``` @@ -471,7 +490,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. -A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito. +A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jamesb93), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. diff --git a/RELEASE.md b/RELEASE.md index 476cf199fdbbeb..59a3336b225331 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -50,6 +50,7 @@ Following is the Release Compatibility Matrix for PyTorch releases: | PyTorch version | Python | Stable CUDA | Experimental CUDA | Stable ROCm | | --- | --- | --- | --- | --- | +| 2.5 | >=3.9, <=3.12, (3.13 experimental) | CUDA 11.8, CUDA 12.1, CUDA 12.4, CUDNN 9.1.0.70 | None | ROCm 6.2 | | 2.4 | >=3.8, <=3.12 | CUDA 11.8, CUDA 12.1, CUDNN 9.1.0.70 | CUDA 12.4, CUDNN 9.1.0.70 | ROCm 6.1 | | 2.3 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 6.0 | | 2.2 | >=3.8, <=3.11, (3.12 experimental) | CUDA 11.8, CUDNN 8.7.0.84 | CUDA 12.1, CUDNN 8.9.2.26 | ROCm 5.7 | diff --git a/aten/src/ATen/AccumulateType.cpp b/aten/src/ATen/AccumulateType.cpp index c4623cc08629c7..5de47571ea6fed 100644 --- a/aten/src/ATen/AccumulateType.cpp +++ b/aten/src/ATen/AccumulateType.cpp @@ -9,6 +9,8 @@ c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { switch (device) { \ case DeviceType::CUDA: \ return CppTypeToScalarType>::value; \ + case DeviceType::XPU: \ + return CppTypeToScalarType>::value; \ case DeviceType::MPS: \ return CppTypeToScalarType>::value; \ default: \ diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6d9152a4d07df2..1896530c0af6bc 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER) endif() EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp") diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 370400b4fa036c..d46abc2e211a9f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -52,7 +52,7 @@ class TORCH_API Context { } else if (device_type == at::kIPU) { return at::detail::getIPUHooks().getDefaultIPUGenerator(device.index()); } else if (device_type == at::kPrivateUse1) { - return at::GetPrivateUse1HooksInterface()->getDefaultGenerator( + return at::detail::getPrivateUse1Hooks().getDefaultGenerator( device.index()); } else { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); @@ -91,7 +91,7 @@ class TORCH_API Context { } else if (device_type == at::kXPU) { return at::detail::getXPUHooks().getDeviceFromPtr(data); } else if (device_type == at::kPrivateUse1) { - return at::GetPrivateUse1HooksInterface()->getDeviceFromPtr(data); + return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data); } else { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); } @@ -182,7 +182,7 @@ class TORCH_API Context { void lazyInitPrivateUse1() { c10::call_once(thp_init, [&] { if (isPrivateUse1HooksRegistered()) { - at::GetPrivateUse1HooksInterface()->initPrivateUse1(); + at::detail::getPrivateUse1Hooks().initPrivateUse1(); } }); } diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index b98d648c684546..db2eccf7954be0 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -299,6 +299,15 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + #define AT_DISPATCH_FLOATING_TYPES_AND4( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ @@ -307,6 +316,26 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) +#define AT_DISPATCH_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index a5e096ba2b79c7..a9d37cff78ca31 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -21,7 +21,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { } else if (at::globalContext().hasXPU()) { return at::detail::getXPUHooks().getPinnedMemoryAllocator(); } else if(at::isPrivateUse1HooksRegistered()) { - return at::GetPrivateUse1HooksInterface()->getPinnedMemoryAllocator(); + return at::detail::getPrivateUse1Hooks().getPinnedMemoryAllocator(); } else { TORCH_CHECK(false, "Need to provide pin_memory allocator to use pin memory.") } diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index dfd4928808602d..6f66e8065731a4 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_ } bool isFunctionalTensor(const at::Tensor& tensor) { - return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); + return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); +} + +bool isBaseTensor(const at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor)); + return unsafeGetFunctionalWrapper(tensor)->isBaseTensor(); } bool isFunctionalTensor(const std::optional& t) { diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index afb3af5fc54f29..ed5daf90e5f44b 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { was_storage_changed_ = true; } + // A FunctionalTensor is considered a base if its not a view of another + // tensor. + bool isBaseTensor() const { + return view_metas_.empty(); + } + c10::SymInt get_storage_size(bool before) { return functional_storage_impl()->get_storage_size(before); } @@ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( return functional_impl; } +TORCH_API bool isBaseTensor(const at::Tensor& tensor); + TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); TORCH_API bool isFunctionalTensor(const std::optional& t); TORCH_API bool isFunctionalTensor( diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index c22f07866f7124..f1ec1a37bf82a6 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -1,6 +1,7 @@ #include -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) +#include #include #endif @@ -18,7 +19,13 @@ ThreadLocalState::ThreadLocalState() torch_dispatch_mode_state_(c10::impl::TorchDispatchModeTLS::get_state()), python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()), python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()), saved_tensors_default_hooks_state_(at::SavedTensorDefaultHooks::get_tls_state()), functionalization_reapply_views_state_(at::functionalization::impl::getFunctionalizationReapplyViewsTLS()), - saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) {} + saved_objects_(at::impl::ThreadLocalPythonObjects::get_state()) { +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) + for(uint8_t i=0; i(i)); + } +#endif +} void ThreadLocalState::set_grad_mode(bool enabled) { autograd_tls_.set_grad_mode(enabled); @@ -54,6 +61,11 @@ void ThreadLocalState::setThreadLocalState( at::functionalization::impl::setFunctionalizationReapplyViewsTLS(state.functionalization_reapply_views_state_); at::impl::ThreadLocalPythonObjects::set_state(state.saved_objects_); +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && !defined(BUILD_LITE_INTERPRETER) + for(uint8_t i=0; i(i), state.autocast_dtypes_[i]); + } +#endif } } // namespace at diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 8419499c3a563c..721ea9957513bf 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -78,6 +78,13 @@ class TORCH_API ThreadLocalState { // TLS for arbitrary python objects that is registered via hooks at::impl::ThreadLocalPythonObjects saved_objects_; +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ + !defined(BUILD_LITE_INTERPRETER) + // TLS for autocast dtypes + std::array + autocast_dtypes_; +#endif + friend class ThreadLocalStateGuard; }; diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index cf33d89e0814ea..f9a1ab0e1de8a7 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -105,6 +105,11 @@ std::string get_cpu_capability() { return "DEFAULT"; case native::CPUCapability::ZVECTOR: return "Z VECTOR"; +#elif defined(HAVE_SVE_CPU_DEFINITION) + case native::CPUCapability::DEFAULT: + return "DEFAULT"; + case native::CPUCapability::SVE256: + return "SVE256"; #else case native::CPUCapability::DEFAULT: return "NO AVX"; diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 10fb72796fc647..8ae66a30dcaf0f 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -69,7 +69,7 @@ thread_local std::array at::ScalarType::Undefined, // Vulkan at::ScalarType::Undefined, // Metal at::kHalf, // XPU - at::ScalarType::Undefined, // MPS + at::kHalf, // MPS at::ScalarType::Undefined, // Meta (tensors with no data) at::kBFloat16, // HPU / HABANA at::ScalarType::Undefined, // SX-Aurora / NEC @@ -206,6 +206,118 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } +TORCH_LIBRARY_IMPL(_, AutocastMPS, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { + // lower_precision_fp + KERNEL_MPS2(_convolution, deprecated, lower_precision_fp) + KERNEL_MPS(_convolution, lower_precision_fp) + KERNEL_MPS(conv1d, lower_precision_fp) + KERNEL_MPS(conv2d, lower_precision_fp) + KERNEL_MPS(conv_tbc, lower_precision_fp) + KERNEL_MPS(conv_transpose1d, lower_precision_fp) + KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp) + KERNEL_MPS(convolution, lower_precision_fp) + KERNEL_MPS(_mps_convolution, lower_precision_fp) + KERNEL_MPS(prelu, lower_precision_fp) + KERNEL_MPS(addmm, lower_precision_fp) + KERNEL_MPS(addmv, lower_precision_fp) + KERNEL_MPS(addr, lower_precision_fp) + KERNEL_MPS(matmul, lower_precision_fp) + KERNEL_MPS(einsum, lower_precision_fp) + KERNEL_MPS(mm, lower_precision_fp) + KERNEL_MPS(mv, lower_precision_fp) + KERNEL_MPS(linear, lower_precision_fp) + KERNEL_MPS(addbmm, lower_precision_fp) + KERNEL_MPS(baddbmm, lower_precision_fp) + KERNEL_MPS(bmm, lower_precision_fp) + KERNEL_MPS(chain_matmul, lower_precision_fp) + KERNEL_MPS(linalg_multi_dot, lower_precision_fp) + KERNEL_MPS(lstm_cell, lower_precision_fp) + + // fp32 + KERNEL_MPS(acos, fp32) + KERNEL_MPS(asin, fp32) + KERNEL_MPS(cosh, fp32) + KERNEL_MPS(erfinv, fp32) + KERNEL_MPS(exp, fp32) + KERNEL_MPS(expm1, fp32) + KERNEL_MPS(log, fp32) + KERNEL_MPS(log10, fp32) + KERNEL_MPS(log2, fp32) + KERNEL_MPS(log1p, fp32) + KERNEL_MPS(reciprocal, fp32) + KERNEL_MPS(rsqrt, fp32) + KERNEL_MPS(sinh, fp32) + KERNEL_MPS(tan, fp32) + KERNEL_MPS2(pow, Tensor_Scalar, fp32) + KERNEL_MPS2(pow, Tensor_Tensor, fp32) + KERNEL_MPS2(pow, Scalar, fp32) + KERNEL_MPS(softplus, fp32) + KERNEL_MPS(layer_norm, fp32) + KERNEL_MPS(native_layer_norm, fp32) + KERNEL_MPS(group_norm, fp32) + KERNEL_MPS2(frobenius_norm, dim, fp32) + KERNEL_MPS(nuclear_norm, fp32) + KERNEL_MPS2(nuclear_norm, dim, fp32) + KERNEL_MPS(batch_norm, fp32) + KERNEL_MPS(cosine_similarity, fp32) + KERNEL_MPS(poisson_nll_loss, fp32) + KERNEL_MPS(cosine_embedding_loss, fp32) + KERNEL_MPS(nll_loss, fp32) + KERNEL_MPS(nll_loss2d, fp32) + KERNEL_MPS(hinge_embedding_loss, fp32) + KERNEL_MPS(kl_div, fp32) + KERNEL_MPS(l1_loss, fp32) + KERNEL_MPS(smooth_l1_loss, fp32) + KERNEL_MPS(huber_loss, fp32) + KERNEL_MPS(mse_loss, fp32) + KERNEL_MPS(margin_ranking_loss, fp32) + KERNEL_MPS(multilabel_margin_loss, fp32) + KERNEL_MPS(soft_margin_loss, fp32) + KERNEL_MPS(triplet_margin_loss, fp32) + KERNEL_MPS(multi_margin_loss, fp32) + KERNEL_MPS(binary_cross_entropy_with_logits, fp32) + KERNEL_MPS(dist, fp32) + KERNEL_MPS(pdist, fp32) + KERNEL_MPS(cdist, fp32) + KERNEL_MPS(renorm, fp32) + KERNEL_MPS(logsumexp, fp32) + + // fp32_set_opt_dtype + KERNEL_MPS(prod, fp32) + KERNEL_MPS2(prod, dim_int, fp32) + KERNEL_MPS2(prod, dim_Dimname, fp32) + KERNEL_MPS2(softmax, int, fp32) + KERNEL_MPS2(softmax, Dimname, fp32) + KERNEL_MPS2(log_softmax, int, fp32) + KERNEL_MPS2(log_softmax, Dimname, fp32) + KERNEL_MPS(cumprod, fp32) + KERNEL_MPS2(cumprod, dimname, fp32) + KERNEL_MPS(cumsum, fp32) + KERNEL_MPS2(cumsum, dimname, fp32) + KERNEL_MPS(linalg_vector_norm, fp32) + KERNEL_MPS(linalg_matrix_norm, fp32) + KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32) + KERNEL_MPS(sum, fp32) + KERNEL_MPS2(sum, dim_IntList, fp32) + KERNEL_MPS2(sum, dim_DimnameList, fp32) + // + // promote + KERNEL_MPS(addcdiv, promote) + KERNEL_MPS(addcmul, promote) + KERNEL_MPS(atan2, promote) + KERNEL_MPS(bilinear, promote) + KERNEL_MPS(cross, promote) + KERNEL_MPS(dot, promote) + KERNEL_MPS(grid_sampler, promote) + KERNEL_MPS(index_put, promote) + KERNEL_MPS(tensordot, promote) + KERNEL_MPS(scatter_add, promote) +} + TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } @@ -224,6 +336,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(linalg_vecdot, lower_precision_fp) KERNEL_CPU(baddbmm, lower_precision_fp) KERNEL_CPU(addmm, lower_precision_fp) + KERNEL_CPU(_addmm_activation, lower_precision_fp) KERNEL_CPU(addbmm, lower_precision_fp) KERNEL_CPU(linear, lower_precision_fp) KERNEL_CPU(_convolution, deprecated, lower_precision_fp) diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 6c9f8c556aef8b..95f1dd2ca0c009 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -145,6 +145,8 @@ inline bool is_autocast_eligible( return tensor.is_xla() && tensor.is_floating_point(); case c10::DeviceType::PrivateUse1: return tensor.is_privateuseone() && tensor.is_floating_point(); + case c10::DeviceType::MPS: + return tensor.is_mps() && tensor.is_floating_point(); default: return false; } @@ -168,6 +170,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type( return DispatchKey::AutocastXLA; case c10::DeviceType::PrivateUse1: return DispatchKey::AutocastPrivateUse1; + case c10::DeviceType::MPS: + return DispatchKey::AutocastMPS; default: throw std::runtime_error( "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); @@ -178,7 +182,7 @@ inline bool is_autocast_available(c10::DeviceType device_type) { if (device_type == at::kCPU || device_type == at::kCUDA || device_type == at::kXPU || device_type == at::kIPU || device_type == at::kHPU || device_type == at::kXLA || - device_type == at::kPrivateUse1) { + device_type == at::kPrivateUse1 || device_type == at::kMPS) { return true; } else { return false; @@ -745,6 +749,27 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. REDISPATCH_SIGNATURE, \ POLICY) +// KERNEL_MPS registration for AutocastMPS +#define KERNEL_MPS(OP, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP), \ + &WrapFunction< \ + CastPolicy::POLICY, \ + DeviceType::MPS, \ + decltype(ATEN_FN(OP)), \ + decltype(ATEN_FN(OP)), \ + &ATEN_FN(OP)>::type::call); + +#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ + &WrapFunction< \ + CastPolicy::POLICY, \ + DeviceType::MPS, \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + &ATEN_FN2(OP, OVERLOAD)>::type::call); + // Op lists for different policies. // To make sure other backends can reuse the policy op list. #define AT_FORALL_LOWER_PRECISION_FP(_) \ diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 5af7f5f564a7dc..1d5fbacdcb8476 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -30,20 +32,17 @@ struct HostBlock { ska::flat_hash_set streams_; // streams on which the block was used }; -/** - * ComparatorSize is used for lookup support in the set of host memory blocks - * using the block size. - */ template -struct ComparatorSize { - bool operator()(const B* a, const B* b) const { - if (a->size_ != b->size_) { - return a->size_ < b->size_; - } - return (uintptr_t)a->ptr_ < (uintptr_t)b->ptr_; - } +struct alignas(64) FreeBlockList { + std::mutex mutex_; + std::deque list_; }; +namespace { + // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes + constexpr size_t MAX_SIZE_INDEX = 64; +} + /** * Note [HostAllocator design] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -81,8 +80,7 @@ struct ComparatorSize { * to abstract the caching mechanism. Any backend needs to provide a customized * implementation by specializing its own public functions and the related * runtime functions. Its template parameter S represents runtime Stream, E - * denotes runtime Event, B indicates the fundamental memory block, and C - * signifies the sorting compartor algorithm for the memory blocks. + * denotes runtime Event, B indicates the fundamental memory block. * * For the interface, we provide a CachingHostAllocatorInterface struct as an * interface. Any backend needs to derive its own host allocator from this @@ -111,9 +109,19 @@ struct ComparatorSize { template < typename S, typename E, - typename B = HostBlock, - typename C = ComparatorSize> + typename B = HostBlock> struct CachingHostAllocatorImpl { + CachingHostAllocatorImpl() { + // Launch the background thread and process events in a loop. + if (pinned_use_background_threads()) { + getBackgroundThreadPool()->run([&]() { + while (true) { + process_events(); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + } + } virtual ~CachingHostAllocatorImpl() = default; public: @@ -123,16 +131,34 @@ struct CachingHostAllocatorImpl { return {nullptr, nullptr}; } - process_events(); + // If we are using background threads, we can process events in the + // background. + if (!pinned_use_background_threads()) { + process_events(); + } + + // Round up the allocation to the nearest power of two to improve reuse. + // These power of two sizes are also used to index into the free list. + size_t roundSize = c10::llvm::PowerOf2Ceil(size); // First, try to allocate from the free list - auto* block = get_free_block(size); + auto* block = get_free_block(roundSize); if (block) { return {block->ptr_, reinterpret_cast(block)}; } - // Round up the allocation to the nearest power of two to improve reuse. - size_t roundSize = c10::llvm::PowerOf2Ceil(size); + // Check in the recently freed blocks with pending events to see if we + // can reuse them. Call get_free_block again after processing events + if (pinned_use_background_threads()) { + process_events_for_specific_size(roundSize); + block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + } + + // Slow path: if we can't allocate from the cached free list, we need + // to create a new block. void* ptr = nullptr; allocate_host_memory(roundSize, &ptr); @@ -171,8 +197,9 @@ struct CachingHostAllocatorImpl { } if (!events) { - std::lock_guard g(free_list_mutex_); - free_list_.insert(block); + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); } else { // restore these events that record by used streams. std::lock_guard g(events_mutex_); @@ -218,22 +245,32 @@ struct CachingHostAllocatorImpl { // Remove all elements from the free list, remove them from the blocks // list, and free the associated pinned memory allocation. This requires - // concurrently holding both the free list mutex and the blocks mutex, and + // concurrently holding both the free list mutexes and the blocks mutex, and // is the only function that concurrently holds multiple mutexes. - std::lock(free_list_mutex_, blocks_mutex_); - std::lock_guard gf(free_list_mutex_, std::adopt_lock); - std::lock_guard gb(blocks_mutex_, std::adopt_lock); - - std::vector blocks_to_remove(free_list_.begin(), free_list_.end()); - free_list_.clear(); - for (auto* block : blocks_to_remove) { - blocks_.erase(block); - ptr_to_block_.erase(block->ptr_); - free_block(block); - delete block; + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + std::vector blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end()); + free_list_[i].list_.clear(); + for (auto* block : blocks_to_remove) { + blocks_.erase(block); + ptr_to_block_.erase(block->ptr_); + free_block(block); + delete block; + } } } + inline size_t size_index(size_t size) { + return c10::llvm::Log2_64_Ceil(size); + } + + virtual bool pinned_use_background_threads() { + return false; + } + virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data"); } @@ -246,19 +283,33 @@ struct CachingHostAllocatorImpl { } virtual B* get_free_block(size_t size) { - std::lock_guard g(free_list_mutex_); - B key(size); - auto it = free_list_.lower_bound(&key); - if (it != free_list_.end()) { - B* block = *it; + auto index = size_index(size); + std::lock_guard g(free_list_[index].mutex_); + if (free_list_[index].list_.size() > 0) { + B* block = free_list_[index].list_.back(); + free_list_[index].list_.pop_back(); block->allocated_ = true; - free_list_.erase(it); return block; } return nullptr; } virtual void process_events() { + // process all events until the last unready event, not for specific size. + process_events_for_specific_size(-1); + } + + // If size is -1, process all events from backwards until the last unready + // event. Otherwise, process events for a specific size and on first ready block + // is found, add it to the free list and return. + virtual void process_events_for_specific_size(int64_t size) { + size_t event_count = 0; + size_t max_events = 0; + { + std::lock_guard g(events_mutex_); + max_events = events_.size(); + } + while (true) { // Avoid calling cudaEventDestroy while holding a mutex, so move // intermediate events out of the lock into this object. @@ -276,6 +327,25 @@ struct CachingHostAllocatorImpl { return; } + if (size != -1) { + if (event_count++ > max_events) { + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + return; + } + if (size != (int64_t)processed->second->size_) { + // if we are processing a specific size, and the size of the block + // doesn't match, we can't use it. + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + continue; + } + } + // otherwise, query the event { // now, see if we can handle this element @@ -284,9 +354,14 @@ struct CachingHostAllocatorImpl { // push the event onto the back if it's not ready. { std::lock_guard g(events_mutex_); - events_.push_back(std::move(*processed)); + if (size == -1) { + events_.push_back(std::move(*processed)); + return; + } else { + events_.push_front(std::move(*processed)); + continue; + } } - return; } } @@ -304,49 +379,57 @@ struct CachingHostAllocatorImpl { } if (available) { - std::lock_guard g(free_list_mutex_); - free_list_.insert(block); + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); + if (size != -1) { + return; + } } } } - /* These following functions are runtime-related. */ - - // Allocate page-locked memory on the host. - virtual void allocate_host_memory(size_t size, void** ptr) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "Not implemented for allocate_host_memory"); + TaskThreadPool* getBackgroundThreadPool() { + static TaskThreadPool* pool = new TaskThreadPool(1); + return pool; } - // Free block and release the pointer contained in block. - virtual void free_block(B* block) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); - } + /* These following functions are runtime-related. */ - // Record an event on stream and store event into events. - virtual void record_stream(std::optional>& events, S stream) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); - } + // Allocate page-locked memory on the host. + virtual void allocate_host_memory(size_t size, void** ptr) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Not implemented for allocate_host_memory"); + } - // Query event if it is completed. - virtual bool query_event(E& event) { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); - } + // Free block and release the pointer contained in block. + virtual void free_block(B* block) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); + } - alignas(64) std::mutex blocks_mutex_; - ska::flat_hash_set blocks_; // block list - ska::flat_hash_map ptr_to_block_; + // Record an event on stream and store event into events. + virtual void record_stream(std::optional>& events, S stream) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); + } - // Note: sharding this mutex seems to be profitable in heavily multi-threaded - // scenarios. - alignas(64) std::mutex free_list_mutex_; - // Note: an alternative datastructure can yield significant wins here in - // microbenchmarks. - std::set free_list_; // free list + // Query event if it is completed. + virtual bool query_event(E& event) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); + } - alignas(64) std::mutex events_mutex_; - std::deque> events_; // event queue paired with block -}; + alignas(64) std::mutex blocks_mutex_; + ska::flat_hash_set blocks_; // block list + ska::flat_hash_map ptr_to_block_; + + // We keep free list as a vector of free lists, one for each power of two + // size. This allows us to quickly find a free block of the right size. + // We use deque to store per size free list and guard the list with its own + // mutex. + alignas(64) std::vector> free_list_ = std::vector>(MAX_SIZE_INDEX); + + alignas(64) std::mutex events_mutex_; + std::deque> events_; // event queue paired with block + }; template struct CachingHostAllocatorInterface : public at::Allocator { diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp index 83d22242e8629d..cebab59d0664eb 100644 --- a/aten/src/ATen/core/function_schema.cpp +++ b/aten/src/ATen/core/function_schema.cpp @@ -197,4 +197,411 @@ bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaAr return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes); } } + +std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { + // eventually this should look almost identical to python arg parser, but + // it is simpler for now to work directly on this schema + + out << schema.name(); + if (!schema.overload_name().empty()) { + out << "." << schema.overload_name(); + } + out << "("; + + bool seen_kwarg_only = false; + for (const auto i : c10::irange(schema.arguments().size())) { + if (i > 0) out << ", "; + if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { + out << "*, "; + seen_kwarg_only = true; + } + out << schema.arguments()[i]; + } + + if(schema.is_vararg()) { + if(!schema.arguments().empty()) + out << ", "; + out << "..."; + } + + out << ") -> "; + + const auto& returns = schema.returns(); + + /* + * We should skip parenthesis if we return a single item and it's not varret, + * or we return nothing but varret. + * + * Need special handling for schema + * aten::items.str(Dict(str, t) self) -> (str,t)[] + * Even though this schema returns a single item, we need add parenthesis. + * The is necessary so the printed schema can be parsed by the C++ SchemaParser + * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly + * treat the return type as a tuple. An alternative is to enhance the Lexer + * to lookahead multiple tokens to accurately decide if the return type is + * a tuple. + */ + bool need_paren = !( + (returns.size() == 1 && !schema.is_varret()) || + (returns.empty() && schema.is_varret())); + + if (returns.size() == 1 && !schema.is_varret()) { + std::stringstream return_ss; + return_ss << returns.at(0); + auto return_str = return_ss.str(); + + // enclosing the single return item with parenthesis if the return type + // starts with a left parenthesis. + // + // There are 2 cases + // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. + // without the extra parenthesis, the c++ schem parser can not parse it. + // 2. something like '-> ((str, str))'. Need extra parenthesis so the return + // type is a single tuple rather than two strings. + // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about + // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15) + // also covers this case. + if (!return_str.empty() && return_str.front() == '(') { + need_paren = true; + } + } + + if (need_paren) { + out << "("; + } + for (const auto i : c10::irange(returns.size())) { + if (i > 0) { + out << ", "; + } + out << returns.at(i); + } + if (schema.is_varret()) { + if (!returns.empty()) { + out << ", "; + } + out << "..."; + } + if (need_paren) { + out << ")"; + } + return out; +} + +static size_t findFirstOutArg(const std::vector& args) { + // find the start of out args in the schema + for (const auto out_start_idx : c10::irange(args.size())) { + if (args.at(out_start_idx).is_out()) { + return out_start_idx; + } + } + return args.size(); +} + +bool Argument::isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && (lhs->alias_info() == rhs->alias_info() + || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr + && *lhs->alias_info() == *rhs->alias_info())))) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + lhs->default_value() != rhs->default_value()) { + return false; + } + return true; +} + +bool Argument::isForwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && (lhs->alias_info() == rhs->alias_info() + || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr + && *lhs->alias_info() == *rhs->alias_info())))) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + lhs->default_value() != rhs->default_value()) { + return false; + } + if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { + return false; + } + return true; +} + +std::string FunctionSchema::formatTypeMismatchMsg( + const Argument& expected, + const std::string& actual_type, + std::optional position, + std::optional value) const { + std::string position_str; + if (position) { + position_str = c10::str("Position: ", *position, "\n"); + } + std::string value_str; + if (value) { + value_str = c10::str("Value: ", *value, "\n"); + } + return c10::str( + name(), + "() ", + expected.formatTypeMismatchMsg(actual_type), + position_str, + value_str, + "Declaration: ", + *this); +} + +bool FunctionSchema::isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not) const { + if (!(name() == old.name() + && overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() + && is_varret() == old.is_varret() + && returns().size() == old.returns().size() + && arguments().size() >= old.arguments().size())) { + return false; + } + for (const auto i : c10::irange(returns().size())) { + // Backwards compatibility requires covariance on argument types + // (i.e. more generic), and contravariance on return types (i.e. + // more specific). + if (!old.returns().at(i).isBackwardCompatibleWith( + returns().at(i), + why_not)) { + return false; + } + } + + // we want to test both out and default args separately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + // make sure among the default args, they are backward compatible + for (const auto i : c10::irange(old_out_start_idx)) { + if (!arguments().at(i).isBackwardCompatibleWith( + old.arguments().at(i), why_not)) { + return false; + } + } + + // Validate that all new arguments provided has a default value + for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { + if (!arguments().at(i).default_value()) { + if (why_not) { + *why_not + << "Function schema not backward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() + << " did not provide a default value."; + } + return false; + } + } + + // now compare the out args + for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { + return false; + } + } + + return true; +} + +bool FunctionSchema::isForwardCompatibleWith( + const FunctionSchema& old, + std::ostringstream& why_not) const { + if (!(name() == old.name() && + overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && + returns().size() == old.returns().size())) { + return false; + } + + // we want to test both out and default args separately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + if (old.arguments().size() - old_out_start_idx != + arguments().size() - new_out_start_idx) { + if (why_not) { + why_not << "Function schema should have the " + << "same number of out arguments"; + } + return false; + } + + // make sure among the default args, they are forward compatible + for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { + if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not + << "'" << arguments().at(i).name() << "'" + << " is not forward compatible with the older version of the schema"; + } + return false; + } + } + + // Validate that all new arguments provided has a default value + for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { + if (!arguments().at(i).default_value()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() + << " did not provide a default value."; + } + return false; + } + + auto default_val = arguments().at(i).default_value().value(); + if (default_val.isList() || default_val.isGenericDict()) { + if (why_not) { + why_not + << "Function schema is not forward compatible since the new argument '" + << arguments().at(i).name() << "' of type " + << arguments().at(i).type()->str() << " has a container type " + << "as its default value."; + } + return false; + } + } + + // now compare the out args + for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isForwardCompatibleWith(old.arguments().at(i))) { + if (why_not) { + why_not << "Out argument '" + << "'" << arguments().at(i).name() + << " is not FC with the older version of the schema"; + } + return false; + } + } + + return true; +} + +std::string FunctionSchema::findErrorInKwargs(const std::vector& kwargs) const { + // First check if any of the kwargs are unknown, i.e. don't match the name of + // any argument in the schema. + for (const auto& kwarg : kwargs) { + if (!std::count_if( + arguments().begin(), + arguments().end(), + [&kwarg](const Argument& argument) { + return argument.name() == kwarg; + })) { + return c10::str( + "Unknown keyword argument '", + kwarg, + "' for operator '", + name(), + "'. Schema: ", + *this); + } + } + // If there are unconsumed kwargs but none of them were unknown, the first + // positional argument present in the kwargs is duplicated. + for (const auto& argument : arguments()) { + if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) { + AT_ASSERT(!argument.default_value()); + return c10::str( + "Argument '", + argument.name(), + "' specified both as positional and ", + "keyword argument. Schema: ", + *this); + } + } + return ""; +} + + +FunctionSchema FunctionSchema::cloneWithRemappedTypes( + const std::function type_map) const { + auto update_args = [&](const std::vector& args) { + std::vector new_args; + new_args.reserve(args.size()); + for(const Argument& arg : args) { + new_args.emplace_back(arg.cloneWithType(type_map(arg.type()))); + } + return new_args; + }; + return FunctionSchema( + name(), + overload_name(), + update_args(arguments()), + update_args(returns()), + is_vararg(), + is_varret()); +} + +// covariant subtyping of list of Arguments +static bool isSubtypeOfList( + ArrayRef child, + ArrayRef parent, + std::ostream* why_not) { + if (child.size() != parent.size()) { + return false; + } + for (const auto i : c10::irange(child.size())) { + const Argument& c = child[i]; + const Argument& p = parent[i]; + if (c.name() != p.name()) { + return false; + } + if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { + return false; + } + } + return true; +} + +bool FunctionSchema::isSubtypeOf( + const FunctionSchema& rhs, + bool as_method, + std::ostream* why_not) const { + size_t start = as_method ? 1 : 0; + // functions are contravariant in arguments but covariant in returns + return isSubtypeOfList( + ArrayRef(rhs.arguments()).slice(start), + ArrayRef(arguments()).slice(start), + why_not) && + isSubtypeOfList(returns(), rhs.returns(), why_not); +} + } // namespace c10 diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index de5247e0eaabca..8dab896b1411d7 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -25,7 +25,7 @@ using AliasTypeSet = std::vector; bool operator==(const Argument& lhs, const Argument& rhs); -struct Argument { +struct TORCH_API Argument { Argument( std::string name = "", const TypePtr& type = nullptr, @@ -622,7 +622,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { return out; } -inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); +TORCH_API std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema); inline std::string toString(const FunctionSchema& schema) { std::ostringstream str; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index e81103ea928a17..a2fff1c130cb57 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -2,328 +2,8 @@ #include #include -// note: windows build doesn't find symbols in operator files unless -// this is a header file - namespace c10 { -inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { - // eventually this should look almost identical to python arg parser, but - // it is simpler for now to work directly on this schema - - out << schema.name(); - if (!schema.overload_name().empty()) { - out << "." << schema.overload_name(); - } - out << "("; - - bool seen_kwarg_only = false; - for (const auto i : c10::irange(schema.arguments().size())) { - if (i > 0) out << ", "; - if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) { - out << "*, "; - seen_kwarg_only = true; - } - out << schema.arguments()[i]; - } - - if(schema.is_vararg()) { - if(!schema.arguments().empty()) - out << ", "; - out << "..."; - } - - out << ") -> "; - - const auto& returns = schema.returns(); - - /* - * We should skip parenthesis if we return a single item and it's not varret, - * or we return nothing but varret. - * - * Need special handling for schema - * aten::items.str(Dict(str, t) self) -> (str,t)[] - * Even though this schema returns a single item, we need add parenthesis. - * The is necessary so the printed schema can be parsed by the C++ SchemaParser - * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly - * treat the return type as a tuple. An alternative is to enhance the Lexer - * to lookahead multiple tokens to accurately decide if the return type is - * a tuple. - */ - bool need_paren = !( - (returns.size() == 1 && !schema.is_varret()) || - (returns.empty() && schema.is_varret())); - - if (returns.size() == 1 && !schema.is_varret()) { - std::stringstream return_ss; - return_ss << returns.at(0); - auto return_str = return_ss.str(); - - // enclosing the single return item with parenthesis if the return type - // starts with a left parenthesis. - // - // There are 2 cases - // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'. - // without the extra parenthesis, the c++ schem parser can not parse it. - // 2. something like '-> ((str, str))'. Need extra parenthesis so the return - // type is a single tuple rather than two strings. - // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about - // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15) - // also covers this case. - if (!return_str.empty() && return_str.front() == '(') { - need_paren = true; - } - } - - if (need_paren) { - out << "("; - } - for (const auto i : c10::irange(returns.size())) { - if (i > 0) { - out << ", "; - } - out << returns.at(i); - } - if (schema.is_varret()) { - if (!returns.empty()) { - out << ", "; - } - out << "..."; - } - if (need_paren) { - out << ")"; - } - return out; -} - -inline size_t findFirstOutArg(const std::vector& args) { - // find the start of out args in the schema - for (const auto out_start_idx : c10::irange(args.size())) { - if (args.at(out_start_idx).is_out()) { - return out_start_idx; - } - } - return args.size(); -} - -inline bool Argument::isBackwardCompatibleWith( - const Argument& old, - std::ostream* why_not) const { - const Argument* lhs = this; - const Argument* rhs = &old; - if (!(lhs->name() == rhs->name() - && lhs->N() == rhs->N() - && (lhs->alias_info() == rhs->alias_info() - || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr - && *lhs->alias_info() == *rhs->alias_info())))) { - return false; - } - if (lhs->kwarg_only() && !rhs->kwarg_only()) { - return false; - } - if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) { - return false; - } - if (rhs->default_value().has_value() && - lhs->default_value() != rhs->default_value()) { - return false; - } - return true; -} - -inline bool Argument::isForwardCompatibleWith( - const Argument& old, - std::ostream* why_not) const { - const Argument* lhs = this; - const Argument* rhs = &old; - if (!(lhs->name() == rhs->name() - && lhs->N() == rhs->N() - && (lhs->alias_info() == rhs->alias_info() - || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr - && *lhs->alias_info() == *rhs->alias_info())))) { - return false; - } - if (lhs->kwarg_only() && !rhs->kwarg_only()) { - return false; - } - if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) { - return false; - } - if (rhs->default_value().has_value() && - lhs->default_value() != rhs->default_value()) { - return false; - } - if (lhs->default_value().has_value() && !rhs->default_value().has_value()) { - return false; - } - return true; -} - -inline std::string FunctionSchema::formatTypeMismatchMsg( - const Argument& expected, - const std::string& actual_type, - std::optional position, - std::optional value) const { - std::string position_str; - if (position) { - position_str = c10::str("Position: ", *position, "\n"); - } - std::string value_str; - if (value) { - value_str = c10::str("Value: ", *value, "\n"); - } - return c10::str( - name(), - "() ", - expected.formatTypeMismatchMsg(actual_type), - position_str, - value_str, - "Declaration: ", - *this); -} - -inline bool FunctionSchema::isBackwardCompatibleWith( - const FunctionSchema& old, - std::ostream* why_not) const { - if (!(name() == old.name() - && overload_name() == old.overload_name() - // we are conservative on is_vararg and is_varret, - // since they are only used by internal operators - && is_vararg() == old.is_vararg() - && is_varret() == old.is_varret() - && returns().size() == old.returns().size() - && arguments().size() >= old.arguments().size())) { - return false; - } - for (const auto i : c10::irange(returns().size())) { - // Backwards compatibility requires covariance on argument types - // (i.e. more generic), and contravariance on return types (i.e. - // more specific). - if (!old.returns().at(i).isBackwardCompatibleWith( - returns().at(i), - why_not)) { - return false; - } - } - - // we want to test both out and default args separately - size_t old_out_start_idx = findFirstOutArg(old.arguments()); - size_t new_out_start_idx = findFirstOutArg(arguments()); - - // make sure among the default args, they are backward compatible - for (const auto i : c10::irange(old_out_start_idx)) { - if (!arguments().at(i).isBackwardCompatibleWith( - old.arguments().at(i), why_not)) { - return false; - } - } - - // Validate that all new arguments provided has a default value - for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) { - if (!arguments().at(i).default_value()) { - if (why_not) { - *why_not - << "Function schema not backward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() - << " did not provide a default value."; - } - return false; - } - } - - // now compare the out args - for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) { - if (!arguments() - .at(i - old_out_start_idx + new_out_start_idx) - .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { - return false; - } - } - - return true; -} - -inline bool FunctionSchema::isForwardCompatibleWith( - const FunctionSchema& old, - std::ostringstream& why_not) const { - if (!(name() == old.name() && - overload_name() == old.overload_name() - // we are conservative on is_vararg and is_varret, - // since they are only used by internal operators - && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() && - returns().size() == old.returns().size())) { - return false; - } - - // we want to test both out and default args separately - size_t old_out_start_idx = findFirstOutArg(old.arguments()); - size_t new_out_start_idx = findFirstOutArg(arguments()); - - if (old.arguments().size() - old_out_start_idx != - arguments().size() - new_out_start_idx) { - if (why_not) { - why_not << "Function schema should have the " - << "same number of out arguments"; - } - return false; - } - - // make sure among the default args, they are forward compatible - for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) { - if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) { - if (why_not) { - why_not - << "'" << arguments().at(i).name() << "'" - << " is not forward compatible with the older version of the schema"; - } - return false; - } - } - - // Validate that all new arguments provided has a default value - for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { - if (!arguments().at(i).default_value()) { - if (why_not) { - why_not - << "Function schema is not forward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() - << " did not provide a default value."; - } - return false; - } - - auto default_val = arguments().at(i).default_value().value(); - if (default_val.isList() || default_val.isGenericDict()) { - if (why_not) { - why_not - << "Function schema is not forward compatible since the new argument '" - << arguments().at(i).name() << "' of type " - << arguments().at(i).type()->str() << " has a container type " - << "as its default value."; - } - return false; - } - } - - // now compare the out args - for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { - if (!arguments() - .at(i - old_out_start_idx + new_out_start_idx) - .isForwardCompatibleWith(old.arguments().at(i))) { - if (why_not) { - why_not << "Out argument '" - << "'" << arguments().at(i).name() - << " is not FC with the older version of the schema"; - } - return false; - } - } - - return true; -} - template inline void FunctionSchema::checkArg( const IValue& value, @@ -341,41 +21,6 @@ inline void FunctionSchema::checkArg( } } -inline std::string FunctionSchema::findErrorInKwargs(const std::vector& kwargs) const { - // First check if any of the kwargs are unknown, i.e. don't match the name of - // any argument in the schema. - for (const auto& kwarg : kwargs) { - if (!std::count_if( - arguments().begin(), - arguments().end(), - [&kwarg](const Argument& argument) { - return argument.name() == kwarg; - })) { - return c10::str( - "Unknown keyword argument '", - kwarg, - "' for operator '", - name(), - "'. Schema: ", - *this); - } - } - // If there are unconsumed kwargs but none of them were unknown, the first - // positional argument present in the kwargs is duplicated. - for (const auto& argument : arguments()) { - if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) { - AT_ASSERT(!argument.default_value()); - return c10::str( - "Argument '", - argument.name(), - "' specified both as positional and ", - "keyword argument. Schema: ", - *this); - } - } - return ""; -} - template inline void FunctionSchema::checkAndNormalizeInputs( std::vector& inputs, @@ -427,57 +72,4 @@ inline void FunctionSchema::checkAndNormalizeInputs( } } -inline FunctionSchema FunctionSchema::cloneWithRemappedTypes( - const std::function type_map) const { - auto update_args = [&](const std::vector& args) { - std::vector new_args; - new_args.reserve(args.size()); - for(const Argument& arg : args) { - new_args.emplace_back(arg.cloneWithType(type_map(arg.type()))); - } - return new_args; - }; - return FunctionSchema( - name(), - overload_name(), - update_args(arguments()), - update_args(returns()), - is_vararg(), - is_varret()); -} - -// covariant subtyping of list of Arguments -inline bool isSubtypeOfList( - ArrayRef child, - ArrayRef parent, - std::ostream* why_not) { - if (child.size() != parent.size()) { - return false; - } - for (const auto i : c10::irange(child.size())) { - const Argument& c = child[i]; - const Argument& p = parent[i]; - if (c.name() != p.name()) { - return false; - } - if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) { - return false; - } - } - return true; -} - -inline bool FunctionSchema::isSubtypeOf( - const FunctionSchema& rhs, - bool as_method, - std::ostream* why_not) const { - size_t start = as_method ? 1 : 0; - // functions are contravariant in arguments but covariant in returns - return isSubtypeOfList( - ArrayRef(rhs.arguments()).slice(start), - ArrayRef(arguments()).slice(start), - why_not) && - isSubtypeOfList(returns(), rhs.returns(), why_not); -} - } // namespace c10 diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 4f6abd66cb8878..38942031befcd6 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -228,6 +228,7 @@ namespace c10 { _(aten, is_autocast_cpu_enabled) \ _(aten, is_autocast_xla_enabled) \ _(aten, get_autocast_dtype) \ + _(aten, is_autocast_mps_enabled) \ FORALL_ATEN_BASE_SYMBOLS(_) \ _(onnx, Add) \ _(onnx, Concat) \ diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index a50e52155b54ae..4455d4c1177312 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -9,7 +9,7 @@ #endif namespace at::cpu { -bool is_cpu_support_avx2() { +bool is_avx2_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx2(); #else @@ -17,7 +17,7 @@ bool is_cpu_support_avx2() { #endif } -bool is_cpu_support_avx512() { +bool is_avx512_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq(); #else @@ -25,7 +25,7 @@ bool is_cpu_support_avx512() { #endif } -bool is_cpu_support_avx512_vnni() { +bool is_avx512_vnni_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni(); #else @@ -33,7 +33,15 @@ bool is_cpu_support_avx512_vnni() { #endif } -bool is_cpu_support_amx_tile() { +bool is_avx512_bf16_supported() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_x86_avx512bf16(); +#else + return false; +#endif +} + +bool is_amx_tile_supported() { #if !defined(__s390x__) && !defined(__powerpc__) return cpuinfo_initialize() && cpuinfo_has_x86_amx_tile(); #else @@ -42,7 +50,7 @@ bool is_cpu_support_amx_tile() { } bool init_amx() { - if (!is_cpu_support_amx_tile()) { + if (!is_amx_tile_supported()) { return false; } diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index 7498367fee4ea7..ad918dde7e0599 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -6,14 +6,17 @@ namespace at::cpu { -TORCH_API bool is_cpu_support_avx2(); -TORCH_API bool is_cpu_support_avx512(); +TORCH_API bool is_avx2_supported(); +TORCH_API bool is_avx512_supported(); // Detect if CPU support Vector Neural Network Instruction. -TORCH_API bool is_cpu_support_avx512_vnni(); +TORCH_API bool is_avx512_vnni_supported(); + +// Detect if CPU supports AVX512_BF16 ISA +TORCH_API bool is_avx512_bf16_supported(); // Detect if CPU support Advanced Matrix Extension. -TORCH_API bool is_cpu_support_amx_tile(); +TORCH_API bool is_amx_tile_supported(); // Enable the system to use AMX instructions. TORCH_API bool init_amx(); diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 48d44dc42c33ce..e54440ed6eedd0 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -78,7 +78,7 @@ struct VecReduceAllSIMD { #endif // defined(CPU_CAPABILITY_AVX512) #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE) template struct VecReduceAllSIMD { static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index a82a8ef1a69457..48b18793b079e7 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -5,6 +5,10 @@ #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* Clang-compatible compiler, targeting arm neon */ #include +#if defined(__ARM_FEATURE_SVE) +/* CLANG-compatible compiler, targeting ARM with SVE */ +#include +#endif #elif defined(_MSC_VER) /* Microsoft C/C++-compatible compiler */ #include @@ -17,6 +21,10 @@ #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ #include +#if defined(__ARM_FEATURE_SVE) +/* GCC-compatible compiler, targeting ARM with SVE */ +#include +#endif #if defined (MISSING_ARM_VLD1) #include #elif defined (MISSING_ARM_VST1) diff --git a/aten/src/ATen/cpu/vec/sve/sve_helper.h b/aten/src/ATen/cpu/vec/sve/sve_helper.h new file mode 100644 index 00000000000000..e511ebb52b2e90 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/sve_helper.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include + +#if defined(CPU_CAPABILITY_SVE) + +// Define the data type of VLS(vector-length specific). +typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); +typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8))); + +#define ptrue svptrue_b8() +#define ZERO_S8 svdup_n_s8(0) +#define ZERO_S16 svdup_n_s16(0) +#define ZERO_S32 svdup_n_s32(0) +#define ZERO_S64 svdup_n_s64(0) +#define ZERO_U8 svdup_n_u8(0) +#define ZERO_U16 svdup_n_u16(0) +#define ZERO_U32 svdup_n_u32(0) +#define ZERO_U64 svdup_n_u64(0) +#define ZERO_F16 svdup_n_f16(0.f) +#define ZERO_F32 svdup_n_f32(0.f) +#define ZERO_F64 svdup_n_f64(0.0) +#define ONE_S8 svdup_n_s8(1) +#define ONE_S16 svdup_n_s16(1) +#define ONE_S32 svdup_n_s32(1) +#define ONE_S64 svdup_n_s64(1) +#define ONE_U8 svdup_n_u8(1) +#define ONE_U16 svdup_n_u16(1) +#define ONE_U32 svdup_n_u32(1) +#define ONE_U64 svdup_n_u64(1) +#define ONE_F16 svdup_n_f16(1.f) +#define ONE_F32 svdup_n_f32(1.f) +#define ONE_F64 svdup_n_f64(1.0) +#define ALL_S8_TRUE_MASK svdup_n_s8(0xff) +#define ALL_S8_FALSE_MASK svdup_n_s8(0x0) +#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff) +#define ALL_S16_FALSE_MASK svdup_n_s16(0x0) +#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff) +#define ALL_S32_FALSE_MASK svdup_n_s32(0x0) +#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff) +#define ALL_S64_FALSE_MASK svdup_n_s64(0x0) +#define ALL_U8_TRUE_MASK svdup_n_u8(0x01) +#define ALL_U8_FALSE_MASK svdup_n_u8(0x00) +#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK) +#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK) +#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK) +#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK) +#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK) +#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK) + +#endif // defined(CPU_CAPABILITY_SVE) diff --git a/aten/src/ATen/cpu/vec/sve/vec_common_sve.h b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h new file mode 100644 index 00000000000000..6f572e16a4c1fa --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_common_sve.h @@ -0,0 +1,176 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include + +#include +#include + +#if defined(CPU_CAPABILITY_SVE) +#include +#include +#include +#include +#endif + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized cast(const Vectorized& src) { + return svreinterpret_f32_f64(src); +} + +template<> +inline Vectorized cast(const Vectorized& src) { + return svreinterpret_f64_f32(src); +} + +#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \ +template<> \ +inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_s##int_bit##_f##float_bit(src); \ +} \ +template<> \ +inline Vectorized cast(const Vectorized& src) { \ + return svreinterpret_f##float_bit##_s##int_bit(src); \ +} + +DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64) +DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64) +DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64) +DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32) +DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32) +DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline gather(const double* base_addr, const Vectorized& vindex_) { + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); +} + +template +std::enable_if_t> +inline gather(const float* base_addr, const Vectorized& vindex_) { + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const double* base_addr, + const Vectorized& vindex_, const Vectorized& mask_) { + svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), + ALL_S64_TRUE_MASK); + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); + return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); +} + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const float* base_addr, + const Vectorized& vindex_, const Vectorized& mask_) { + svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), + ALL_S32_TRUE_MASK); + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); + return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// Only works for inputs in the range: [-2^51, 2^51] +// From: https://stackoverflow.com/a/41148578 +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); + return svsub_s64_x(ptrue, + svreinterpret_s64_f64(x), + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); +} + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return svcvt_s32_f32_x(ptrue, src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3} + // b = {b0, b1, b2, b3} + // group cols crossing lanes: + // return {a0, b0, a1, b1} + // {a2, b2, a3, b3} + return std::make_pair(Vectorized(svzip1_f64(a, b)), + Vectorized(svzip2_f64(a, b))); +} + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + return std::make_pair(Vectorized(svzip1_f32(a, b)), + Vectorized(svzip2_f32(a, b))); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1} + // b = {a2, b2, a3, b3} + // swap lanes: + // return {a0, a1, a2, a3} + // {b0, b1, b2, b3} + return std::make_pair(Vectorized(svuzp1_f64(a, b)), + Vectorized(svuzp2_f64(a, b))); +} + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // swap lanes: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + return std::make_pair(Vectorized(svuzp1_f32(a, b)), + Vectorized(svuzp2_f32(a, b))); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_double.h b/aten/src/ATen/cpu/vec/sve/vec_double.h new file mode 100644 index 00000000000000..911e69da90d4c9 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_double.h @@ -0,0 +1,505 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> class Vectorized { +private: + vls_float64_t values; +public: + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(double); + } + Vectorized() {} + Vectorized(svfloat64_t v) : values(v) {} + Vectorized(double val) { + values = svdup_n_f64(val); + } + template> + Vectorized(Args... vals) { + __at_align__ double buffer[size()] = { vals... }; + values = svld1_f64(ptrue, buffer); + } + operator svfloat64_t() const { + return values; + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), + ALL_S64_TRUE_MASK); + return svsel_f64(mask, b, a); + } + template + static Vectorized arange(double base = 0., step_t step = static_cast(1)) { + __at_align__ double buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f64(ptrue, buffer); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f64(svwhilelt_b64(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f64(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b64(0ull, count); + return svld1_f64(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f64(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b64(0ull, count); + svst1_f64(pg, reinterpret_cast(ptr), values); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int64_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f64(ptrue, values, ZERO_F64); + svst1_s64(ptrue, mask_array, svsel_s64(svbool_mask, + ALL_S64_TRUE_MASK, + ALL_S64_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f64(ptrue, values, ZERO_F64); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any(ptrue, svcmpuo_f64(ptrue, svsub_f64_x(ptrue, values, values), ZERO_F64)); + } + Vectorized map(double (*f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f64_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f64(NAN); + const auto nan_mask = svcmpuo_f64(ptrue, values, ZERO_F64); + const auto pi = svdup_n_f64(c10::pi); + + const auto neg_mask = svcmplt_f64(ptrue, values, ZERO_F64); + auto angle = svsel_f64(neg_mask, pi, ZERO_F64); + angle = svsel_f64(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return USE_SLEEF(Vectorized(Sleef_acosdx_u10sve(values)),map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF( Vectorized(Sleef_acoshdx_u10sve(values)),map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF(Vectorized(Sleef_asindx_u10sve(values)),map(std::asin)); + } + Vectorized atan() const { + return USE_SLEEF(Vectorized(Sleef_atandx_u10sve(values)),map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF(Vectorized(Sleef_atanhdx_u10sve(values)),map(std::atanh)); + } + Vectorized atan2(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_atan2dx_u10sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized copysign(const Vectorized &sign) const { + USE_SLEEF( {return Vectorized(Sleef_copysigndx_sve(values, sign));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + } + ) + } + Vectorized erf() const { + return USE_SLEEF(Vectorized(Sleef_erfdx_u10sve(values)),map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF(Vectorized(Sleef_erfcdx_u15sve(values)),map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF(Vectorized(Sleef_expdx_u10sve(values)),map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF(Vectorized(Sleef_exp2dx_u10sve(values)),map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF(Vectorized(Sleef_expm1dx_u10sve(values)),map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + USE_SLEEF({return Vectorized(Sleef_fmoddx_sve(values, q));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + } + ) + } + Vectorized hypot(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_hypotdx_u05sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + }) + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized &b) const { + USE_SLEEF( + { + return Vectorized(Sleef_nextafterfx_sve(values, b)); + }, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized log() const { + return USE_SLEEF(Vectorized(Sleef_logdx_u10sve(values)),map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF(Vectorized(Sleef_log2dx_u10sve(values)),map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF(Vectorized(Sleef_log10dx_u10sve(values)),map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF(Vectorized(Sleef_log1pdx_u10sve(values)),map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF( Vectorized(Sleef_sindx_u10sve(values)),map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF(Vectorized(Sleef_sinhdx_u10sve(values)),map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF(Vectorized(Sleef_cosdx_u10sve(values)),map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF( Vectorized(Sleef_coshdx_u10sve(values)),map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f64_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f64_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f64_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f64_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF( Vectorized(Sleef_tandx_u10sve(values)),map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF( Vectorized(Sleef_tanhdx_u10sve(values)),map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f64_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF( Vectorized(Sleef_lgammadx_u10sve(values)),map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f64_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f64_x(ptrue, values, ONE_F64); + } + Vectorized rsqrt() const { + return svdivr_f64_x(ptrue, svsqrt_f64_x(ptrue, values), ONE_F64); + } + Vectorized pow(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_powdx_u10sve(values, b));}, + { + __at_align__ double tmp[size()]; + __at_align__ double tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f64(ptrue, values, other); + return svsel_f64(mask, ALL_F64_TRUE_MASK, ALL_F64_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return svadd_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return svsub_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return svmul_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_f64_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return svmax_f64_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return svmin_f64_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return svmin_f64_x(ptrue, max, svmax_f64_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return svmin_f64_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return svmax_f64_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(svand_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(svorr_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f64_s64(sveor_s64_x(ptrue, svreinterpret_s64_f64(a), svreinterpret_s64_f64(b))); +} + +Vectorized inline Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized inline Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f64(ptrue, dst + i, svldnt1_f64(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b64(i, n); + svst1_f64(pg, dst + i, svldnt1_f64(pg, src + i)); + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return svmad_f64_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h new file mode 100644 index 00000000000000..4da7cc53710006 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -0,0 +1,570 @@ +#pragma once + +#include +#include +#include +#include +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +template <> class Vectorized { +private: + vls_float32_t values; +public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return VECTOR_WIDTH / sizeof(float); + } + Vectorized() {} + Vectorized(svfloat32_t v) : values(v) {} + Vectorized(float val) { + values = svdup_n_f32(val); + } + template> + Vectorized(Args... vals) { + __at_align__ float buffer[size()] = { vals... }; + values = svld1_f32(ptrue, buffer); + } + operator svfloat32_t() const { + return values; + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask_) { + svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), + ALL_S32_TRUE_MASK); + return svsel_f32(mask, b, a); + } + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { + __at_align__ float buffer[size()]; + for (int64_t i = 0; i < size(); i++) { + buffer[i] = base + i * step; + } + return svld1_f32(ptrue, buffer); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + if (count == 0) { + return a; + } else if (count < size()) { + return svsel_f32(svwhilelt_b32(0ull, count), b, a); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_f32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_f32(pg, reinterpret_cast(ptr)); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + svst1_f32(ptrue, reinterpret_cast(ptr), values); + } else { + svbool_t pg = svwhilelt_b32(0ull, count); + svst1_f32(pg, reinterpret_cast(ptr), values); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int64_t zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + int64_t mask = 0; + __at_align__ int32_t mask_array[size()]; + + svbool_t svbool_mask = svcmpeq_f32(ptrue, values, ZERO_F32); + svst1_s32(ptrue, mask_array, svsel_s32(svbool_mask, + ALL_S32_TRUE_MASK, + ALL_S32_FALSE_MASK)); + for (int64_t i = 0; i < size(); ++i) { + if (mask_array[i]) mask |= (1ull << i); + } + return mask; + } + Vectorized isnan() const { + // NaN check + svbool_t mask = svcmpuo_f32(ptrue, values, ZERO_F32); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + bool has_inf_nan() const { + return svptest_any(ptrue, svcmpuo_f32(ptrue, svsub_f32_x(ptrue, values, values), ZERO_F32)); + } + Vectorized map(float (*f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return svabs_f32_x(ptrue, values); + } + Vectorized angle() const { + const auto nan_vec = svdup_n_f32(NAN); + const auto nan_mask = svcmpuo_f32(ptrue, values, ZERO_F32); + const auto pi = svdup_n_f32(c10::pi); + + const auto neg_mask = svcmplt_f32(ptrue, values, ZERO_F32); + auto angle = svsel_f32(neg_mask, pi, ZERO_F32); + angle = svsel_f32(nan_mask, nan_vec, angle); + return angle; + } + Vectorized real() const { + return values; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return values; + } + Vectorized acos() const { + return USE_SLEEF(Vectorized(Sleef_acosfx_u10sve(values)),map(std::acos)); + } + Vectorized acosh() const { + return USE_SLEEF(Vectorized(Sleef_acoshfx_u10sve(values)),map(std::acosh)); + } + Vectorized asin() const { + return USE_SLEEF(Vectorized(Sleef_asinfx_u10sve(values)),map(std::asin)); + } + Vectorized atan() const { + return USE_SLEEF(Vectorized(Sleef_atanfx_u10sve(values)),map(std::atan)); + } + Vectorized atanh() const { + return USE_SLEEF(Vectorized(Sleef_atanhfx_u10sve(values)),map(std::atanh)); + } + Vectorized atan2(const Vectorized &b) const { + USE_SLEEF({return Vectorized(Sleef_atan2fx_u10sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++){ + tmp[i] = std::atan2(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized copysign(const Vectorized &sign) const { + + USE_SLEEF({return Vectorized(Sleef_copysignfx_sve(values, sign));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; + store(tmp); + sign.store(tmp_sign); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::copysign(tmp[i], tmp_sign[i]); + } + return loadu(tmp); + }) + } + Vectorized erf() const { + return USE_SLEEF(Vectorized(Sleef_erffx_u10sve(values)),map(std::erf)); + } + Vectorized erfc() const { + return USE_SLEEF(Vectorized(Sleef_erfcfx_u15sve(values)),map(std::erfc)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return USE_SLEEF(Vectorized(Sleef_expfx_u10sve(values)),map(std::exp)); + } + Vectorized exp2() const { + return USE_SLEEF(Vectorized(Sleef_exp2fx_u10sve(values)),map(std::exp2)); + } + Vectorized expm1() const { + return USE_SLEEF(Vectorized(Sleef_expm1fx_u10sve(values)),map(std::expm1)); + } + Vectorized exp_u20() const { + return exp(); + } + Vectorized fmod(const Vectorized& q) const { + USE_SLEEF({return Vectorized(Sleef_fmodfx_sve(values, q));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; + store(tmp); + q.store(tmp_q); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::fmod(tmp[i], tmp_q[i]); + } + return loadu(tmp); + }) + } + Vectorized hypot(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_hypotfx_u05sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::hypot(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized nextafter(const Vectorized &b) const { + USE_SLEEF( + { + return Vectorized(Sleef_nextafterfx_sve(values, b)); + }, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); ++i) { + tmp[i] = std::nextafter(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + Vectorized log() const { + return USE_SLEEF(Vectorized(Sleef_logfx_u10sve(values)),map(std::log)); + } + Vectorized log2() const { + return USE_SLEEF(Vectorized(Sleef_log2fx_u10sve(values)),map(std::log2)); + } + Vectorized log10() const { + return USE_SLEEF(Vectorized(Sleef_log10fx_u10sve(values)),map(std::log10)); + } + Vectorized log1p() const { + return USE_SLEEF(Vectorized(Sleef_log1pfx_u10sve(values)),map(std::log1p)); + } + Vectorized frac() const; + Vectorized sin() const { + return USE_SLEEF(Vectorized(Sleef_sinfx_u10sve(values)),map(std::sin)); + } + Vectorized sinh() const { + return USE_SLEEF(Vectorized(Sleef_sinhfx_u10sve(values)),map(std::sinh)); + } + Vectorized cos() const { + return USE_SLEEF(Vectorized(Sleef_cosfx_u10sve(values)),map(std::cos)); + } + Vectorized cosh() const { + return USE_SLEEF(Vectorized(Sleef_coshfx_u10sve(values)),map(std::cosh)); + } + Vectorized ceil() const { + return svrintp_f32_x(ptrue, values); + } + Vectorized floor() const { + return svrintm_f32_x(ptrue, values); + } + Vectorized neg() const { + return svneg_f32_x(ptrue, values); + } + Vectorized round() const { + return svrinti_f32_x(ptrue, values); + } + Vectorized tan() const { + return USE_SLEEF(Vectorized(Sleef_tanfx_u10sve(values)),map(std::tan)); + } + Vectorized tanh() const { + return USE_SLEEF(Vectorized(Sleef_tanhfx_u10sve(values)),map(std::tanh)); + } + Vectorized trunc() const { + return svrintz_f32_x(ptrue, values); + } + Vectorized lgamma() const { + return USE_SLEEF(Vectorized(Sleef_lgammafx_u10sve(values)),map(std::lgamma)); + } + Vectorized sqrt() const { + return svsqrt_f32_x(ptrue, values); + } + Vectorized reciprocal() const { + return svdivr_f32_x(ptrue, values, ONE_F32); + } + Vectorized rsqrt() const { + return svdivr_f32_x(ptrue, svsqrt_f32_x(ptrue, values), ONE_F32); + } + Vectorized pow(const Vectorized &b) const { + USE_SLEEF( {return Vectorized(Sleef_powfx_u10sve(values, b));}, + { + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; + store(tmp); + b.store(tmp_b); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = std::pow(tmp[i], tmp_b[i]); + } + return loadu(tmp); + } + ) + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + svbool_t mask = svcmpeq_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator!=(const Vectorized& other) const { + svbool_t mask = svcmpne_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<(const Vectorized& other) const { + svbool_t mask = svcmplt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator<=(const Vectorized& other) const { + svbool_t mask = svcmple_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>(const Vectorized& other) const { + svbool_t mask = svcmpgt_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized operator>=(const Vectorized& other) const { + svbool_t mask = svcmpge_f32(ptrue, values, other); + return svsel_f32(mask, ALL_F32_TRUE_MASK, ALL_F32_FALSE_MASK); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return svadd_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return svsub_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return svmul_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_f32_x(ptrue, a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized inline Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return svmax_f32_x(ptrue, a, b); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return svmin_f32_x(ptrue, a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return svmin_f32_x(ptrue, max, svmax_f32_x(ptrue, min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return svmin_f32_x(ptrue, max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return svmax_f32_x(ptrue, min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(svand_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(svorr_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return svreinterpret_f32_s32(sveor_s32_x(ptrue, svreinterpret_s32_f32(a), svreinterpret_s32_f32(b))); +} + +Vectorized inline Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized inline Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svst1_f32(ptrue, dst + i, svldnt1_f32(ptrue, src + i)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + svbool_t pg = svwhilelt_b32(i, n); + svst1_f32(pg, dst + i, svldnt1_f32(pg, src + i)); + } +} + +template <> +inline void convert(const float *src, at::Half *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), + ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svuzp1_f16(svcvt_f16_f32_x(ptrue, svldnt1_f32(pg_32, src + i)), + ZERO_F16); + svst1_f16(pg_16, reinterpret_cast(dst) + i, src_vec); + } +} + +template <> +inline void convert(const at::Half *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_16 = svwhilelt_b16(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_16 = svwhilelt_b16(i, n); + pg_32 = svwhilelt_b32(i, n); + svfloat16_t src_vec = svzip1_f16(svldnt1_f16(pg_16, reinterpret_cast(src) + i), + ZERO_F16); + svst1_f32(pg_32, dst + i, svcvt_f32_f16_x(ptrue, src_vec)); + } +} + +template <> +inline void convert(const bool *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_f32(pg_32, dst + i, svsel_f32(mask, ONE_F32, ZERO_F32)); + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return svmad_f32_x(ptrue, a, b, c); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_int.h b/aten/src/ATen/cpu/vec/sve/vec_int.h new file mode 100644 index 00000000000000..6a081bd00a7514 --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_int.h @@ -0,0 +1,410 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +#define VEC_INT_SVE_TEMPLATE(vl, bit) \ +template <> class Vectorized { \ +private: \ + vls_int##bit##_t values; \ +public: \ + using value_type = int##bit##_t; \ + using size_type = int; \ + static constexpr size_type size() { \ + return vl; \ + } \ + Vectorized() {} \ + Vectorized(svint##bit##_t v) : values(v) {} \ + Vectorized(int##bit##_t val) { \ + values = svdup_n_s##bit(val); \ + } \ + template> \ + Vectorized(Args... vals) { \ + __at_align__ int##bit##_t buffer[size()] = { vals... }; \ + values = svld1_s##bit(ptrue, buffer); \ + } \ + operator svint##bit##_t() const { \ + return values; \ + } \ + static Vectorized blendv(const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& mask_) { \ + svbool_t mask = svcmpeq_s##bit(ptrue, mask_, ALL_S##bit##_TRUE_MASK); \ + return svsel_s##bit(mask, b, a); \ + } \ + /* step sometimes requires a higher precision type (e.g., T=int, step_t=double) */ \ + template \ + static Vectorized arange(int##bit##_t base = 0, step_t step = static_cast(1)) { \ + __at_align__ int##bit##_t buffer[size()]; \ + for (int64_t i = 0; i < size(); i++) { \ + buffer[i] = base + i * step; \ + } \ + return svld1_s##bit(ptrue, buffer); \ + } \ + static Vectorized set(const Vectorized& a, \ + const Vectorized& b, \ + int##bit##_t count = size()) { \ + if (count == 0) { \ + return a; \ + } else if (count < size()) { \ + return svsel_s##bit(svwhilelt_b##bit(0ull, count), b, a); \ + } \ + return b; \ + } \ + static Vectorized loadu(const void* ptr, int64_t count = size()) { \ + if (count == size()) \ + return svld1_s##bit(ptrue, reinterpret_cast(ptr)); \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + return svld1_s##bit(pg, reinterpret_cast(ptr)); \ + } \ + void store(void* ptr, int64_t count = size()) const { \ + if (count == size()) { \ + svst1_s##bit(ptrue, reinterpret_cast(ptr), values); \ + } else { \ + svbool_t pg = svwhilelt_b##bit(0ull, count); \ + svst1_s##bit(pg, reinterpret_cast(ptr), values); \ + } \ + } \ + const int##bit##_t& operator[](int idx) const = delete; \ + int##bit##_t& operator[](int idx) = delete; \ + Vectorized abs() const { \ + return svabs_s##bit##_x(ptrue, values); \ + } \ + Vectorized real() const { \ + return values; \ + } \ + Vectorized imag() const { \ + return svdup_n_s##bit(0); \ + } \ + Vectorized conj() const { \ + return values; \ + } \ + Vectorized frac() const; \ + Vectorized neg() const { \ + return svneg_s##bit##_x(ptrue, values); \ + } \ + Vectorized operator==(const Vectorized& other) const { \ + svbool_t mask = svcmpeq_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator!=(const Vectorized& other) const { \ + svbool_t mask = svcmpne_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<(const Vectorized& other) const { \ + svbool_t mask = svcmplt_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator<=(const Vectorized& other) const { \ + svbool_t mask = svcmple_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>(const Vectorized& other) const { \ + svbool_t mask = svcmpgt_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized operator>=(const Vectorized& other) const { \ + svbool_t mask = svcmpge_s##bit(ptrue, values, other); \ + return svsel_s##bit(mask, ALL_S##bit##_TRUE_MASK, ALL_S##bit##_FALSE_MASK); \ + } \ + Vectorized eq(const Vectorized& other) const; \ + Vectorized ne(const Vectorized& other) const; \ + Vectorized gt(const Vectorized& other) const; \ + Vectorized ge(const Vectorized& other) const; \ + Vectorized lt(const Vectorized& other) const; \ + Vectorized le(const Vectorized& other) const; \ +}; \ +template <> \ +Vectorized inline operator+(const Vectorized& a, \ + const Vectorized& b) { \ + return svadd_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator-(const Vectorized& a, \ + const Vectorized& b) { \ + return svsub_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator*(const Vectorized& a, \ + const Vectorized& b) { \ + return svmul_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline maximum(const Vectorized& a, \ + const Vectorized& b) { \ + return svmax_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline minimum(const Vectorized& a, \ + const Vectorized& b) { \ + return svmin_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline clamp(const Vectorized& a, \ + const Vectorized& min, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, svmax_s##bit##_x(ptrue, min, a)); \ +} \ +template <> \ +Vectorized inline clamp_max(const Vectorized& a, \ + const Vectorized& max) { \ + return svmin_s##bit##_x(ptrue, max, a); \ +} \ +template <> \ +Vectorized inline clamp_min(const Vectorized& a, \ + const Vectorized& min) { \ + return svmax_s##bit##_x(ptrue, min, a); \ +} \ +template <> \ +Vectorized inline operator&(const Vectorized& a, \ + const Vectorized& b) { \ + return svand_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator|(const Vectorized& a, \ + const Vectorized& b) { \ + return svorr_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +Vectorized inline operator^(const Vectorized& a, \ + const Vectorized& b) { \ + return sveor_s##bit##_x(ptrue, a, b); \ +} \ +template <> \ +inline Vectorized operator~(const Vectorized& a) { \ + return sveor_s##bit##_x(ptrue, a, svdup_n_s##bit(-1)); \ +} \ +Vectorized inline Vectorized::eq(const Vectorized& other) const { \ + return (*this == other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::ne(const Vectorized& other) const { \ + return (*this != other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::gt(const Vectorized& other) const { \ + return (*this > other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::ge(const Vectorized& other) const { \ + return (*this >= other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::lt(const Vectorized& other) const { \ + return (*this < other) & Vectorized(1); \ +} \ +Vectorized inline Vectorized::le(const Vectorized& other) const { \ + return (*this <= other) & Vectorized(1); \ +} + +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int64_t), 64) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int32_t), 32) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int16_t), 16) +VEC_INT_SVE_TEMPLATE(VECTOR_WIDTH / sizeof(int8_t), 8) + +template +Vectorized inline intdiv_nosve(const Vectorized& a, const Vectorized& b) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] /= values_b[i]; + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_s64_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return svdiv_s32_x(ptrue, a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return intdiv_nosve(a, b); +} + +template <> +inline void convert(const int32_t *src, int64_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svst1_s64(pg_64, dst + i, svunpklo_s64(svldnt1_s32(pg_32, src + i))); + } +} + +template <> +inline void convert(const int64_t *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_32 = svwhilelt_b32(i, n); + pg_64 = svwhilelt_b64(i, n); + svint64_t src_vec_s64 = svldnt1_s64(pg_64, src + i); + svfloat32_t src_vec_f32 = svuzp1_f32(svcvt_f32_s64_x(pg_64, src_vec_s64), ZERO_F32); + svst1_f32(pg_32, dst + i, src_vec_f32); + } +} + +template <> +inline void convert(const int32_t *src, float *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b32(i, n); + svint32_t src_vec = svldnt1_s32(pg, src + i); + svst1_f32(pg, dst + i, svcvt_f32_s32_x(pg, src_vec)); + } +} + +template <> +inline void convert(const bool *src, int64_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_64 = svwhilelt_b64(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_64 = svwhilelt_b64(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint64_t src_vec_u64 = svunpklo_u64(svunpklo_u32(svunpklo_u16(src_vec_u8))); + svbool_t mask = svcmpne_u64(pg_64, src_vec_u64, ZERO_U64); + svst1_s64(pg_64, dst + i, svsel_s64(mask, ONE_S64, ZERO_S64)); + } +} + +template <> +inline void convert(const bool *src, int32_t *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg_8 = svwhilelt_b8(0ull, Vectorized::size()); + svbool_t pg_32 = svwhilelt_b32(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg_8 = svwhilelt_b8(i, n); + pg_32 = svwhilelt_b32(i, n); + svuint8_t src_vec_u8 = svldnt1_u8(pg_8, reinterpret_cast(src) + i); + svuint32_t src_vec_u32 = svunpklo_u32(svunpklo_u16(src_vec_u8)); + svbool_t mask = svcmpne_u32(pg_32, src_vec_u32, ZERO_U32); + svst1_s32(pg_32, dst + i, svsel_s32(mask, ONE_S32, ZERO_S32)); + } +} + +template <> +inline void convert(const uint8_t *src, bool *dst, int64_t n) { + const int64_t fraction = n % Vectorized::size(); + svbool_t pg = svwhilelt_b8(0ull, Vectorized::size()); +#pragma unroll + for (int64_t i = 0; i < n - fraction; i += Vectorized::size()) { + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8(pg, reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +#pragma unroll + for (int64_t i = n - fraction; i < n; i += Vectorized::size()) { + pg = svwhilelt_b8(i, n); + svbool_t mask = svcmpne_u8(pg, svldnt1_u8(pg, src + i), ZERO_U8); + svst1_u8(pg, reinterpret_cast(dst) + i, + svsel_u8(mask, ALL_U8_TRUE_MASK, ALL_U8_FALSE_MASK)); + } +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return svlsl_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s64_x(ptrue, a, svreinterpret_u64_s64(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s32_x(ptrue, a, svreinterpret_u32_s32(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s16_x(ptrue, a, svreinterpret_u16_s16(b)); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return svasr_s8_x(ptrue, a, svreinterpret_u8_s8(b)); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/sve/vec_qint.h b/aten/src/ATen/cpu/vec/sve/vec_qint.h new file mode 100644 index 00000000000000..7c49c041ddf2ff --- /dev/null +++ b/aten/src/ATen/cpu/vec/sve/vec_qint.h @@ -0,0 +1,567 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with SVE] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vectorized::float_num_vecs +// iterations. + +namespace at { +namespace vec { +// Note [CPU_CAPABILITY namespace] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// This header, and all of its subheaders, will be compiled with +// different architecture flags for each supported set of vector +// intrinsics. So we need to make sure they aren't inadvertently +// linked together. We do this by declaring objects in an `inline +// namespace` which changes the name mangling, but can still be +// accessed as `at::vec`. +inline namespace CPU_CAPABILITY { + +#if defined(CPU_CAPABILITY_SVE) + +// NOTE: These are low-performance implementations that we fall back on +// if we are not building with SVE. This may not be an issue, because +// currently for quantization we assume the user has at least SVE +// installed, so these can simply act as a reference implementation. +// +// If in the future we relax this requirement (SVE+), we should probably +// revisit these implementations + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + using size_type = int; + static constexpr size_type size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / Vectorized::size(); + } + + static constexpr int int_num_vecs() { + return size() / Vectorized::size(); + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = + at::native::dequantize_val(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point) const { + float_vec_return_type rv; + float tmp_scale[Vectorized::size()]; + float tmp_zero_point[Vectorized::size()]; + scale.store(tmp_scale); + zero_point.store(tmp_zero_point); + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[Vectorized::size()]; + for (int j = 0; j < Vectorized::size(); ++j) { + tmp_vals[j] = + at::native::dequantize_val(tmp_scale[j], tmp_zero_point[j], T(vals[Vectorized::size() * i + j])); + } + rv[i] = Vectorized::loadu(tmp_vals); + } + return rv; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + VECTOR_WIDTH / 4>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_s32(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b32(0ull, count); + return svld1_s32(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + VECTOR_WIDTH>(ptr) {} +#if 1 + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ value_type tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, reinterpret_cast(ptr), count * sizeof(value_type)); + return loadu(tmp_values); + } +#else + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return svld1_u8(ptrue, reinterpret_cast(ptr)); + svbool_t pg = svwhilelt_b8(0ull, count); + return svld1_u8(pg, reinterpret_cast(ptr)); + } +#endif + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array::size()> float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * Vectorized::size()], Vectorized::size()); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + Vectorized::size() * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_SVE) + +}}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 6f7abf193b77c2..68367b81bd8a03 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -7,9 +7,13 @@ #include #if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)) -#include +#if defined(CPU_CAPABILITY_SVE256) +#include +#else #include #include +#endif +#include #include #include #include @@ -314,6 +318,17 @@ inline Vectorized flip(const Vectorized & v) { return flip8(v); } +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m256i* self_ = reinterpret_cast(self.as_bytes()); + const __m256i* other_ = reinterpret_cast(other.as_bytes()); + __m256i out = _mm256_and_si256(*self_, *other_); + Vectorized ret; + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + #endif // (defined(CPU_CAPABILITY_AVX2) }} // namepsace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index e567c1925be840..12c11abb748dea 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -1097,7 +1097,7 @@ inline Vectorized convert_float_##name(const Vectorized& a, const V return Vectorized::loadu(arr2); \ } CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16); -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE256) inline std::tuple, Vectorized> convert_half_float(const Vectorized& a) { static_assert(Vectorized::size() == 2 * Vectorized::size()); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h index 7242e6facacaef..b0f109fc875026 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h @@ -43,6 +43,26 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { @@ -52,6 +72,24 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + template <> inline Vectorized convert_to_fp_of_same_size( const Vectorized& src); diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index 0d0fe99252a7d0..dab1790b26ab01 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -636,6 +636,21 @@ inline void transpose_mxn( _mm256_storeu_ps(&dst[7 * ld_dst], th); } +template<> +inline void transpose_mxn( + const float* src, + int64_t ld_src, + float* dst, + int64_t ld_dst) { + transpose_mxn( + src , ld_src, dst, ld_dst); + transpose_mxn( + src + 8, ld_src, dst + 8 * ld_dst, ld_dst); + transpose_mxn( + src + 8 * ld_src, ld_src, dst + 8, ld_dst); + transpose_mxn( + src + 8 * ld_src + 8, ld_src, dst + 8 * ld_dst + 8, ld_dst); +} #endif }} // namespace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_mask.h b/aten/src/ATen/cpu/vec/vec256/vec256_mask.h index dd6a8c52d82655..3460abe17e159d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_mask.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_mask.h @@ -9,49 +9,218 @@ inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + if constexpr (std::is_same_v) { + result[i] = Vectorized( + _mm256_maskload_ps(ptr + i * Vectorized::size(), int_mask)); + } else { + result[i] = Vectorized( + _mm256_maskload_epi32(ptr + i * Vectorized::size(), int_mask)); + } + } + return result; + } +}; + template struct VecMaskLoad< T, - 1, + 2, mask_t, 1, typename std::enable_if_t< - std::is_same_v || std::is_same_v || - std::is_same_v, - void>> { - static inline VectorizedN apply( + std::is_same_v || std::is_same_v>> { + static inline VectorizedN apply( const T* ptr, const VecMask& vec_mask) { - auto int_mask = vec_mask.template cast()[0]; - if constexpr (std::is_same_v) { - return Vectorized(_mm256_maskload_ps(ptr, int_mask)); + auto int64_mask = vec_mask.template cast(); + auto result = at::vec::VectorizedN(); + if constexpr (std::is_same_v) { + result[0] = _mm256_maskload_pd(ptr, int64_mask[0]); + result[1] = _mm256_maskload_pd( + ptr + at::vec::Vectorized::size(), int64_mask[1]); } else { - return Vectorized(_mm256_maskload_epi32(ptr, int_mask)); + result[0] = _mm256_maskload_epi64( + reinterpret_cast(ptr), int64_mask[0]); + result[1] = _mm256_maskload_epi64( + reinterpret_cast( + ptr + at::vec::Vectorized::size()), + int64_mask[1]); } + return result; } }; // TODO: add specialization of VecMaskLoad for bfloat16/half and int8/uint8 -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm256_castsi256_ps(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_ps(vec_mask[i]); + } + return result; } }; -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm256_castps_si256(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castps_si256(vec_mask[i]); + } + return result; } }; -template -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - auto int_vec = convert(VectorizedN(vec_mask)); - return VecMask(int_vec).cast(); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castpd_si256(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm256_castsi256_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); + } +}; + +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); + } +}; +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; @@ -71,6 +240,42 @@ inline bool VecMask::all_masked() const { return mask == 0xff; } +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = all_zero && (_mm256_testz_si256(vec_mask[i], vec_mask[i]) > 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 4) { + return _mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[j])) & + (1 << (i - j * 4)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && + (_mm256_movemask_pd(_mm256_castsi256_pd(vec_mask[i])) == 0x0f); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 8659fbb20f9665..3430e654d7f1f8 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -843,7 +843,7 @@ Vectorized inline maximum(const Vectorized& a, const V return a.maximum(b); } -#else +#elif !defined(CPU_CAPABILITY_SVE256) // NOTE: These are low-performance implementations that we fall back on // if we are not building with AVX2. This may not be an issue, because diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h index c7fa23b23a6074..d593d184c3190a 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -274,6 +274,18 @@ inline Vectorized flip(const Vectorized & v) { return flip8(v); } +inline Vectorized operator&&( + const Vectorized& self, + const Vectorized& other) { + const __m512i* self_ = reinterpret_cast(self.as_bytes()); + const __m512i* other_ = reinterpret_cast(other.as_bytes()); + __m512i out = _mm512_and_si512(*self_, *other_); + Vectorized ret; + // We do not have a constructer that takes __m512i, so we need to memcpy + std::memcpy(ret, &out, ret.size() * sizeof(bool)); + return ret; +} + #endif // defined(CPU_CAPABILITY_AVX512) }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h index d2006fcf1c4755..dcdb682c56208d 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -1348,41 +1348,20 @@ static inline void _transpose_mxn_half_32_32(__m512i r[], __m512i d[]) { // Code referred to FBGEMM: // https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6 template<> -inline void transpose_mxn( - const BFloat16* src, - int64_t ld_src, - BFloat16* dst, - int64_t ld_dst) { - // Load from memory - __m512i r[32]; -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); - } - - __m512i d[32]; - _transpose_mxn_half_32_32(r, d); - - // Store to dst -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - _mm512_storeu_si512(dst + i* ld_dst, d[i]); - } -} - -template ::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0> -inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) { +inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst, int M, int N) { // load from src - __mmask32 src_mask = (1 << N) - 1; + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); __m512i r[32]; int i; - for (i = 0; i < M; ++i) { - r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } } for (; i < 32; ++i) { r[i] = _mm512_setzero_si512(); @@ -1392,48 +1371,39 @@ inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, in _transpose_mxn_half_32_32(r, d); // store to dst - __mmask32 dst_mask = (1 << M) - 1; - for (i = 0; i < N; ++i) { - _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } } } -template<> -inline void transpose_mxn( - const Half* src, - int64_t ld_src, - Half* dst, - int64_t ld_dst) { - // Load from memory - __m512i r[32]; -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - r[i] = _mm512_loadu_si512(reinterpret_cast(src + i* ld_src)); - } - - __m512i d[32]; - _transpose_mxn_half_32_32(r, d); - - // Store to dst -#ifndef __msvc_cl__ -#pragma unroll(32) -#endif - for (int i = 0; i < 32; ++i) { - _mm512_storeu_si512(dst + i* ld_dst, d[i]); - } +template ::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0> +inline void transpose_mxn(const BFloat16* src, int64_t ld_src, BFloat16* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); } -template ::value && ((M < 32 && M != 16) || (N < 32 && N != 16)), int> = 0> -inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) { +template<> +inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst, int M, int N) { + TORCH_CHECK(M <= 32 && N <= 32, "transpose_mxn expects M, N <= 32."); // load from src - __mmask32 src_mask = (1 << N) - 1; __m512i r[32]; int i; - for (i = 0; i < M; ++i) { - r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + if (N == 32) { + for (i = 0; i < M; ++i) { + r[i] = _mm512_loadu_si512(&src[i * ld_src]); + } + } else { + __mmask32 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + r[i] = _mm512_maskz_loadu_epi16(src_mask, &src[i * ld_src]); + } } for (; i < 32; ++i) { r[i] = _mm512_setzero_si512(); @@ -1443,12 +1413,24 @@ inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld _transpose_mxn_half_32_32(r, d); // store to dst - __mmask32 dst_mask = (1 << M) - 1; - for (i = 0; i < N; ++i) { - _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + if (M == 32) { + for (i = 0; i < N; ++i) { + _mm512_storeu_si512(&dst[i * ld_dst], d[i]); + } + } else { + __mmask32 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_epi16(&dst[i * ld_dst], dst_mask, d[i]); + } } } +template ::value && ((M <= 32 && M != 16) || (N <= 32 && N != 16)), int> = 0> +inline void transpose_mxn(const Half* src, int64_t ld_src, Half* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + template <> class Vectorized: public Vectorized16 { public: diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h index fcdfa3d5934c97..78c7045fb30e3f 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h @@ -43,6 +43,26 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_bfloat16(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_bfloat16_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply(const VectorizedN& src) { @@ -52,6 +72,24 @@ struct VecConvert { } }; +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + result[0] = convert_float_half(src[0], src[1]); + return result; + } +}; + +template <> +struct VecConvert { + static inline VectorizedN apply(const VectorizedN& src) { + VectorizedN result; + std::tie(result[0], result[1]) = convert_half_float(src[0]); + return result; + } +}; + template <> struct VecConvert { static inline VectorizedN apply( diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index 289f927a10689b..4e21eae91cb240 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -582,15 +582,20 @@ Vectorized inline fmsub(const Vectorized& a, const Vectorized::value && M <= 16 && N <= 16, int> = 0> -inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { +inline void transpose_mxn_16x16(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { + TORCH_CHECK(M <= 16 && N <= 16, "transpose_mxn expects M, N <= 16."); // load from src to registers - __mmask16 src_mask = (1 << N) - 1; __m512 input[16]; int i; - for (i = 0; i < M; ++i) { - input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + if (N == 16) { + for (i = 0; i < M; ++i) { + input[i] = _mm512_loadu_ps(&src[i * ld_src]); + } + } else { + __mmask16 src_mask = (1 << N) - 1; + for (i = 0; i < M; ++i) { + input[i] = _mm512_maskz_loadu_ps(src_mask, &src[i * ld_src]); + } } for (; i < 16; ++i) { // Not really needed but to avoid uninitialized variable warning. @@ -640,16 +645,62 @@ inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t _mm512_shuffle_f32x4(input[8 * i + 3], input[8 * i + 7], 0xdd); } - // store from registers to dst - __mmask16 dst_mask = (1 << M) - 1; for (i = 0; i < N; ++i) { if (i < 8) { input[i] = _mm512_shuffle_f32x4(temp[i], temp[8 + i], 0x88); } else { input[i] = _mm512_shuffle_f32x4(temp[i - 8], temp[i], 0xdd); } - _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); } + + // store from registers to dst + if (M == 16) { + for (i = 0; i < N; ++i) { + _mm512_storeu_ps(&dst[i * ld_dst], input[i]); + } + } else { + __mmask16 dst_mask = (1 << M) - 1; + for (i = 0; i < N; ++i) { + _mm512_mask_storeu_ps(&dst[i * ld_dst], dst_mask, input[i]); + } + } +} + +template<> +inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst, int M, int N) { + int64_t i = 0; + for (; i < M / 16 * 16; i += 16) { + int64_t j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, 16); + } + // handle remainder j + int nrem = N - j; + if (nrem > 0) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, 16, nrem); + } + } + // handle remainder i + int mrem = M - i; + if (mrem > 0) { + int j = 0; + for (; j < N / 16 * 16; j += 16) { + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, 16); + } + // handle remainder j + int nrem = N - j; + transpose_mxn_16x16( + src + i * ld_src + j, ld_src, dst + j * ld_dst + i, ld_dst, mrem, nrem); + } +} + +template ::value, int> = 0> +inline void transpose_mxn(const float* src, int64_t ld_src, float* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); } #endif diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h index 9ba1b18372eb54..cdb433af252541 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h @@ -9,50 +9,139 @@ inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) -template +template struct VecMaskLoad< T, - 1, + dst_n, mask_t, - 1, + mask_n, typename std::enable_if_t< - std::is_same_v || std::is_same_v || - std::is_same_v, + (mask_n == dst_n * 2 && dst_n >= 1) && + (std::is_same_v || std::is_same_v), void>> { - static inline VectorizedN apply( + static inline VectorizedN apply( const T* ptr, - const VecMask& vec_mask) { + const VecMask& vec_mask) { at::vec::Vectorized zero_vec(0); auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto int_mask = vec_mask.template cast()[0]; - auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - if constexpr (std::is_same_v) { - return Vectorized(_mm512_mask_loadu_ps(zero_vec, mmask, ptr)); - } else { - return Vectorized(_mm512_mask_loadu_epi32(zero_vec, mmask, ptr)); + VectorizedN tmp_vec; + VectorizedN result; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int64_mask = VecMask(tmp_vec).template cast(); + auto int_mask = int64_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } } + return result; } }; -template +template +struct VecMaskLoad< + T, + dst_n, + mask_t, + dst_n, + typename std::enable_if_t< + std::is_same_v || std::is_same_v, + void>> { + static inline VectorizedN apply( + const T* ptr, + const VecMask& vec_mask) { + at::vec::Vectorized zero_vec(0); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast()[0]; + auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); + if constexpr (std::is_same_v) { + result[i] = Vectorized(_mm512_mask_loadu_ps( + zero_vec, mmask, ptr + i * Vectorized::size())); + } else { + result[i] = Vectorized(_mm512_mask_loadu_epi32( + zero_vec, mmask, ptr + i * Vectorized::size())); + } + } + return result; + } +}; + +template struct VecMaskLoad< data_t, - 1, + dst_n, mask_t, - 1, + dst_n, typename std::enable_if< std::is_same_v || std::is_same_v>::type> { - static inline VectorizedN apply( + static inline VectorizedN apply( const data_t* ptr, - const VecMask& vec_mask) { + const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto int_mask = vec_mask.template cast()[0]; - auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - auto zero = _mm256_set1_epi16(0); - auto temp = _mm256_mask_loadu_epi16(zero, mmask, ptr); - return Vectorized( - _mm512_inserti32x8(_mm512_castsi256_si512(temp), zero, 1)); + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < dst_n; i++) { + auto tmp_mask = VecMask(vec_mask[i]); + auto int_mask = tmp_mask.template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; + } +}; + +template +struct VecMaskLoad< + data_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n && dst_n >= 1) && + (std::is_same_v || std::is_same_v)>> { + static inline VectorizedN apply( + const data_t* ptr, + const VecMask& vec_mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + VectorizedN result; + VectorizedN tmp_vec; + for (int i = 0; i < dst_n; i++) { + tmp_vec[0] = vec_mask[2 * i]; + tmp_vec[1] = vec_mask[2 * i + 1]; + auto int_mask = VecMask(tmp_vec).template cast(); + auto mmask0 = _mm512_cmp_epi32_mask(int_mask[0], all_ones, _MM_CMPINT_EQ); + auto mmask1 = _mm512_cmp_epi32_mask(int_mask[1], all_ones, _MM_CMPINT_EQ); + auto zero = _mm256_set1_epi16(0); + auto temp0 = _mm256_mask_loadu_epi16( + zero, mmask0, ptr + (2 * i) * Vectorized::size()); + auto temp1 = _mm256_mask_loadu_epi16( + zero, mmask1, ptr + (2 * i + 1) * Vectorized::size()); + result[i] = Vectorized( + _mm512_inserti32x8(_mm512_castsi256_si512(temp0), temp1, 1)); + } + return result; } }; @@ -78,41 +167,155 @@ struct VecMaskLoad< } }; -template -struct VecMaskLoad { - static inline VectorizedN apply( - const int64_t* ptr, +template +struct VecMaskLoad< + data_t, + 2, + mask_t, + 1, + typename std::enable_if< + std::is_same_v || + std::is_same_v>::type> { + static inline VectorizedN apply( + const data_t* ptr, const VecMask& vec_mask) { auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); - auto zero = _mm512_set1_epi64(0); + at::vec::Vectorized zero_vec(0); auto int_mask = vec_mask.template cast()[0]; auto mmask = _mm512_cmp_epi32_mask(int_mask, all_ones, _MM_CMPINT_EQ); - at::vec::VectorizedN result; - result[0] = _mm512_mask_loadu_epi64(zero, (__mmask8)mmask, ptr); - result[1] = _mm512_mask_loadu_epi64(zero, (__mmask8)(mmask >> 8), ptr + 8); + at::vec::VectorizedN result; + if constexpr (std::is_same_v) { + result[0] = _mm512_mask_loadu_pd(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_pd(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } else { + result[0] = _mm512_mask_loadu_epi64(zero_vec, (__mmask8)mmask, ptr); + result[1] = + _mm512_mask_loadu_epi64(zero_vec, (__mmask8)(mmask >> 8), ptr + 8); + } return result; } }; -template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm512_castsi512_ps(vec_mask[0])); +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_ps(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castps_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castpd_si512(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result[i] = _mm512_castsi512_pd(vec_mask[i]); + } + return result; + } +}; + +template +struct VecMaskCast< + int64_t, + dst_n, + mask_t, + mask_n, + typename std::enable_if_t< + (dst_n == 2 * mask_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + auto int_mask = vec_mask.template cast(); +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < mask_n; ++i) { + auto int64_vec = + convert(VectorizedN(int_mask[i])); + result[2 * i] = int64_vec[0]; + result[2 * i + 1] = int64_vec[1]; + } + return VecMask(result); + } +}; + +template +struct VecMaskCast< + dst_t, + dst_n, + int64_t, + mask_n, + typename std::enable_if_t< + (mask_n == 2 * dst_n) && + (std::is_same_v || std::is_same_v), + void>> { + static inline VecMask apply( + const VecMask& vec_mask) { + VectorizedN result; + VectorizedN int64_vec; + for (int i = 0; i < dst_n; ++i) { + int64_vec[0] = vec_mask[2 * i]; + int64_vec[1] = vec_mask[2 * i + 1]; + result[i] = convert(int64_vec); + } + return VecMask(result).template cast(); } }; template <> -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - return Vectorized(_mm512_castps_si512(vec_mask[0])); +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; -template -struct VecMaskCast { - static inline VecMask apply(const VecMask& vec_mask) { - auto int_vec = convert(VectorizedN(vec_mask)); - return VecMask(int_vec).cast(); +template <> +struct VecMaskCast { + static inline VecMask apply(const VecMask& vec_mask) { + auto int64_mask = VecMaskCast::apply(vec_mask); + return VecMaskCast::apply(int64_mask); } }; @@ -133,6 +336,41 @@ inline bool VecMask::all_masked() const { return mask == 0xffff; } +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + bool all_zero = true; + for (int i = 0; i < N; ++i) { + all_zero = + all_zero && (_mm512_test_epi64_mask(vec_mask[i], vec_mask[i]) == 0); + if (!all_zero) { + return all_zero; + } + } + return all_zero; + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + for (int j = 0; j < N; ++j) { + if (i < (j + 1) * 8) { + return _mm512_movepi64_mask(vec_mask[j]) & (1 << (i - j * 8)); + } + } + return false; + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + bool all_masked = true; + for (int i = 0; i < N; ++i) { + all_masked = all_masked && (_mm512_movepi64_mask(vec_mask[i]) == 0xff); + if (!all_masked) { + return all_masked; + } + } + return all_masked; + } +}; + #define VEC_MASK_METHOD_WITH_CAST_TO_INT( \ T, N, return_type, method, args_def, args) \ template <> \ diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index af4da53793666f..ba7865cb522f26 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -947,6 +947,17 @@ inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const return a * b - c; } +template +Vectorized inline operator&&( + const Vectorized& a, + const Vectorized& b) { + Vectorized ret; + for (int i = 0; i != Vectorized::size(); i++) { + ret[i] = a[i] && b[i]; + } + return ret; +} + template std::enable_if_t> inline gather(T const* base_addr, const Vectorized>& vindex) { @@ -979,7 +990,7 @@ inline mask_gather(const Vectorized& src, T const* base_addr, buffer[i] = src_arr[i]; } } - mask = Vectorized(); // "zero out" mask + mask = Vectorized(static_cast(0)); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } @@ -1126,8 +1137,8 @@ inline Vectorized flip(const Vectorized & data) { // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading // dimension of `src` and `ld_dst` is the leading dimension of `dst`. -template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { +template +inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { dst[j*ld_dst + i] = src[i*ld_src + j]; @@ -1135,6 +1146,11 @@ inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) } } +template +inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { + transpose_mxn(src, ld_src, dst, ld_dst, M, N); +} + }} // namespace at::vec::CPU_CAPABILITY // additional headers for more operations that depend on vec_base diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index ebec8d4a3e3c5d..a39ffa3090b8eb 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -2,7 +2,6 @@ #include #include - namespace at::vec { inline namespace CPU_CAPABILITY { @@ -69,7 +68,7 @@ struct VecMaskTo { } }; -template +template struct VecMaskCast { static inline VecMask apply( const VecMask& vec_mask) { @@ -84,6 +83,29 @@ struct VecMaskCast { } }; +template +struct VecMaskCheck { + static inline bool all_zero(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of( + mask, mask + VectorizedN::size(), [](T m) { return m == static_cast(0); }); + } + + static inline bool all_masked(const VectorizedN& vec_mask) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return std::all_of( + mask, mask + VectorizedN::size(), [](T m) { return m != static_cast(0); }); + } + + static inline bool is_masked(const VectorizedN& vec_mask, int i) { + __at_align__ T mask[VectorizedN::size()]; + vec_mask.store(mask); + return mask[i] != static_cast(0); + } +}; + template class VecMask { public: @@ -147,6 +169,17 @@ class VecMask { return result; } + static VecMask set( + const VecMask& a, + const VecMask& b, + int64_t count = size()) { + VectorizedN result = VectorizedN::set( + VectorizedN(a), + VectorizedN(b), + count); + return result; + } + void store(bool* b, int count = size()) { constexpr int L = (VectorizedN::size() + Vectorized::size() - 1)/ Vectorized::size(); auto res = this->to(); @@ -170,23 +203,15 @@ class VecMask { } inline bool all_zero() const { - __at_align__ T mask[size()]; - mask_.store(mask); - return std::all_of( - mask, mask + size(), [](T m) { return m == static_cast(0); }); + return VecMaskCheck::all_zero(mask_); } inline bool all_masked() const { - __at_align__ T mask[size()]; - mask_.store(mask); - return std::all_of( - mask, mask + size(), [](T m) { return m != static_cast(0); }); + return VecMaskCheck::all_masked(mask_); } inline bool is_masked(int i) const { - __at_align__ T mask[size()]; - mask_.store(mask); - return mask[i] != static_cast(0); + return VecMaskCheck::is_masked(mask_, i); } inline operator VectorizedN() const { diff --git a/aten/src/ATen/cpu/vec/vec_n.h b/aten/src/ATen/cpu/vec/vec_n.h index 2c0f2eef4b7ec6..8c4e622682a285 100644 --- a/aten/src/ATen/cpu/vec/vec_n.h +++ b/aten/src/ATen/cpu/vec/vec_n.h @@ -88,6 +88,9 @@ class VectorizedN { template = 0> VectorizedN(const Vectorized& val) : values({val}) {} + template = 0> + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) : values({val_0, val_1}) {} + template = 0> inline operator Vectorized() const { return values[0]; diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index eea4a9f421a0c5..9b3fd5dc6e4dd9 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1408,7 +1408,6 @@ void scaled_gemm( const void *result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - void* amax_ptr, bool use_fast_accum) { #if CUDA_VERSION >= 11080 || defined(USE_ROCM) const auto computeType = CUBLAS_COMPUTE_32F; @@ -1421,13 +1420,9 @@ void scaled_gemm( computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr); - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); -#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200) - // Amax support in ROCm as of 6.2 - if (isFloat8Type(result_dtype)) { - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr); + if (result_scale_ptr != nullptr) { + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } -#endif #ifndef USE_ROCM computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode); #endif diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 2c6cef95f79fe8..e6f0c5a9a373ba 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -140,7 +140,6 @@ void scaled_gemm( const void* result_scale_ptr, int64_t result_ld, ScalarType result_dtype, - void* amax_ptr, bool use_fast_accum); #define CUDABLAS_BGEMM_ARGTYPES(Dtype) \ diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index c88aaa04826ac7..511a0c2884587f 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -98,7 +98,7 @@ struct CUDACachingHostAllocatorImpl pinned_use_cuda_host_register()) { void* ptr = block->ptr_; AT_CUDA_CHECK(cudaHostUnregister(ptr)); - free(ptr); + std::free(ptr); } else { AT_CUDA_CHECK(cudaFreeHost(block->ptr_)); } @@ -123,6 +123,11 @@ struct CUDACachingHostAllocatorImpl return true; } + bool pinned_use_background_threads() override { + return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + pinned_use_background_threads(); + } + EventPool::Event create_event_internal(DeviceIndex idx) { // Leak the event pool to avoid shutdown issue. static auto* event_pool = new EventPool(); @@ -175,7 +180,7 @@ struct CUDACachingHostAllocatorImpl // Here we do regular allocation, pre-fault/map the pages, and then do // cudaHostRegister with GPU mapping flags to lock the pages, so we // can minimize the cost for the cuda global lock. - *ptr = malloc(roundSize); + *ptr = std::malloc(roundSize); // Parallelize the mapping/registering of pages to reduce wall time size_t pageSize = (1 << 12); // 4kB pages diff --git a/aten/src/ATen/cuda/cub-RadixSortKeys.cu b/aten/src/ATen/cuda/cub-RadixSortKeys.cu index 74e82ae55cdee0..56571295155268 100644 --- a/aten/src/ATen/cuda/cub-RadixSortKeys.cu +++ b/aten/src/ATen/cuda/cub-RadixSortKeys.cu @@ -50,7 +50,7 @@ void radix_sort_keys( int64_t begin_bit, \ int64_t end_bit); -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) +AT_FORALL_SCALAR_TYPES_AND3(Bool, BFloat16, Half, AT_INSTATIATE_CUB_TEMPLATES) AT_INSTATIATE_CUB_TEMPLATES(uint16_t, UInt16) AT_INSTATIATE_CUB_TEMPLATES(uint32_t, UInt32) AT_INSTATIATE_CUB_TEMPLATES(uint64_t, UInt64) diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 5053f6693b96da..1b7c898758558f 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -188,7 +188,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "ROCM_VERSION", [rocm_version]() { return rocm_version; }, - [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; }); + [rocm_version](auto&& k) { + TUNABLE_LOG1("ROCM_VERSION validation: expect ", k, " to match ", rocm_version); + return rocm_version == k ? OK : FAIL; + }); } // gfx arch { @@ -196,7 +199,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "GCN_ARCH_NAME", [gcn_arch_name]() { return gcn_arch_name; }, - [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; }); + [gcn_arch_name](auto&& k) { + TUNABLE_LOG1("GCN_ARCH_NAME validation: expect ", k, " to match ", gcn_arch_name); + return gcn_arch_name == k ? OK : FAIL; + }); } // rocblas { @@ -212,7 +218,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "ROCBLAS_VERSION", [rocblas_version]() { return rocblas_version; }, - [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; }); + [rocblas_version](auto&& k) { + TUNABLE_LOG1("ROCBLAS_VERSION validation: expect ", k, " to match ", rocblas_version); + return rocblas_version == k ? OK : FAIL; + }); } // hipblaslt { @@ -226,7 +235,10 @@ TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "HIPBLASLT_VERSION", [hipblaslt_version]() { return hipblaslt_version; }, - [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; }); + [hipblaslt_version](auto&& k) { + TUNABLE_LOG1("HIPBLASLT_VERSION validation: expect ", k, " to match ", hipblaslt_version); + return hipblaslt_version == k ? OK : FAIL; + }); } #endif } diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 50a7344b0260c2..00b02e91b4f359 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -104,7 +104,6 @@ class DefaultScaledGemmOp : public Callable> { params->c_scale_ptr, params->ldc, params->c_dtype, - params->amax_ptr, params->use_fast_accum); return OK; } diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index 7eefdfc7269cb6..61409db3ac6807 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -53,4 +53,4 @@ struct TORCH_API AcceleratorHooksInterface { }; } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 0a9004d353424d..180ff68588edd7 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -114,4 +114,4 @@ TORCH_API const MPSHooksInterface& getMPSHooks(); } // namespace detail } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index cee4463c818504..1480436fb4f1d8 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -104,6 +104,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); return nullptr; } + + virtual PyObject* getDeviceCapability(DeviceIndex device) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return nullptr; + } }; struct TORCH_API MTIAHooksArgs {}; diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp index ff267a41506bb2..258d05f87e6d2f 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.cpp @@ -11,13 +11,6 @@ TORCH_API void RegisterPrivateUse1HooksInterface(at::PrivateUse1HooksInterface* privateuse1_hooks = hook_; } -TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface() { - TORCH_CHECK( - privateuse1_hooks != nullptr, - "Please register PrivateUse1HooksInterface by `RegisterPrivateUse1HooksInterface` first."); - return privateuse1_hooks; -} - TORCH_API bool isPrivateUse1HooksRegistered() { return privateuse1_hooks != nullptr; } diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index 62fcb75d2a5ef2..e321f484deeace 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -12,7 +12,7 @@ namespace at { struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { ~PrivateUse1HooksInterface() override = default; virtual const at::Generator& getDefaultGenerator( - c10::DeviceIndex device_index) { + c10::DeviceIndex device_index) const { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`."); @@ -24,24 +24,26 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`."); } - bool isPinnedPtr(const void* data) const override { + virtual bool isPinnedPtr(const void* data) const override { return false; } - Allocator* getPinnedMemoryAllocator() const override { + virtual Allocator* getPinnedMemoryAllocator() const override { TORCH_CHECK( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`."); } - bool hasPrimaryContext(DeviceIndex device_index) const override { + virtual bool hasPrimaryContext(DeviceIndex device_index) const override { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); } virtual void initPrivateUse1() const {} - virtual void resizePrivateUse1Bytes(const c10::Storage &storage, size_t newsize) const { + virtual void resizePrivateUse1Bytes( + const c10::Storage& storage, + size_t newsize) const { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`."); @@ -53,8 +55,6 @@ struct TORCH_API PrivateUse1HooksArgs {}; TORCH_API void RegisterPrivateUse1HooksInterface( at::PrivateUse1HooksInterface* hook_); -TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface(); - TORCH_API bool isPrivateUse1HooksRegistered(); namespace detail { diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index 1b7b7f99e46df2..f4cd9a34b5752b 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -81,4 +81,4 @@ namespace detail { TORCH_API const XPUHooksInterface& getXPUHooks(); } // namespace detail } // namespace at -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 5739e88d5ddcc1..cca20e9e553e5c 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -23,6 +23,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) { OP_DECOMPOSE(dropout_); OP_DECOMPOSE(feature_alpha_dropout_); OP_DECOMPOSE(feature_dropout_); + OP_DECOMPOSE(dropout); + OP_DECOMPOSE(_scaled_dot_product_attention_math); + OP_DECOMPOSE(scaled_dot_product_attention); } static void unsupportedData(const c10::OperatorHandle& op, torch::jit::Stack* stack) { @@ -227,7 +230,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { m.impl("reshape", native::reshape_symint); OP_DECOMPOSE(resolve_conj); OP_DECOMPOSE(resolve_neg); - OP_DECOMPOSE(rms_norm); + m.impl("rms_norm", native::rms_norm_symint); OP_DECOMPOSE(row_stack); OP_DECOMPOSE(rrelu); OP_DECOMPOSE(rrelu_); @@ -235,7 +238,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(relu6_); OP_DECOMPOSE(prelu); OP_DECOMPOSE2(softmax, int); - OP_DECOMPOSE(scaled_dot_product_attention); OP_DECOMPOSE(special_gammainc); OP_DECOMPOSE(special_gammaincc); OP_DECOMPOSE(special_logit); @@ -261,7 +263,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(special_xlogy); OP_DECOMPOSE2(special_xlogy, other_scalar); OP_DECOMPOSE2(special_xlogy, self_scalar); - OP_DECOMPOSE(_scaled_dot_product_attention_math); m.impl("split.sizes", native::split_symint); @@ -386,6 +387,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE2(to, dtype); OP_DECOMPOSE2(to, dtype_layout); OP_DECOMPOSE2(to, other); + + // Random ops that are also registered here + OP_DECOMPOSE(dropout); + OP_DECOMPOSE(_scaled_dot_product_attention_math); + OP_DECOMPOSE(scaled_dot_product_attention); } } // namespace at::functorch diff --git a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp index 6047a6eddb65d8..fed7fecc217b98 100644 --- a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp +++ b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp @@ -496,6 +496,11 @@ _scaled_dot_product_flash_attention_batch_rule( bool return_debug_mask, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); @@ -540,6 +545,11 @@ fourOutputs _scaled_dot_product_efficient_attention_batch_rule( bool is_causal, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); @@ -577,6 +587,11 @@ _scaled_dot_product_cudnn_attention_batch_rule( bool return_debug_mask, c10::optional scale ) { + if (dropout_p > 0) { + auto maybe_layer = maybeCurrentDynamicLayer(); + RandomnessType randomness = maybe_layer->randomness(); + check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value()); + } auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim); auto query_ = moveBatchDimToFront(query, query_bdim); auto key_ = moveBatchDimToFront(key, key_bdim); diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 8626f4eb9fe4f9..e3e9a980f30b6c 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -779,6 +779,28 @@ std::tuple> scatter_reduce_batch_rule( self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce); } +std::tuple> scatter_reduce_two_batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim, + const c10::string_view reduce, + bool include_self) { + return scatter_batch_rule(ATEN_FN2(scatter_reduce, two), + self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self); +} + +std::tuple> scatter_reduce__two_batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim, + const c10::string_view reduce, + bool include_self) { + return scatter_batch_rule(ATEN_FN2(scatter_reduce_, two), + self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self); +} + std::tuple> scatter_value_reduce_batch_rule( const Tensor& self, std::optional self_bdim, int64_t dim, @@ -1250,6 +1272,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(scatter_add, scatter_add_batch_rule); VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule); VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule); + VMAP_SUPPORT2(scatter_reduce, two, scatter_reduce_two_batch_rule); + VMAP_SUPPORT2(scatter_reduce_, two, scatter_reduce__two_batch_rule); // as_strided_scatter does not work with the for-loop fallback today, // because as_strided_scatter will return an output that matches // the strides/storage_offset of its input. diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm index 26e4452eb7b3bd..1aafdbfd37a05f 100644 --- a/aten/src/ATen/mps/MPSFallback.mm +++ b/aten/src/ATen/mps/MPSFallback.mm @@ -88,7 +88,6 @@ static Tensor slow_conv2d_forward_mps(const Tensor& self, m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); - m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps); m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp index d252a464cd4ac1..8b1182982002fa 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -117,7 +117,7 @@ namespace { return at::mkldnn_adaptive_avg_pool2d(input, C10_AS_INTARRAYREF_SLOW(output_size)); } - if (!input.is_quantized() && output_size[0] == 1 && output_size[1] == 1 && !input.is_xpu()) { + if (!input.is_quantized() && output_size[0] == 1 && output_size[1] == 1) { // in this case, adaptive pooling is just computing mean over hw // dimensions, which can be done more efficiently #if defined(C10_MOBILE) && defined(USE_XNNPACK) diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp index e5327185844fe4..4897864a378b71 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -313,7 +313,7 @@ Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_siz "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ", "but received {", output_size[0], ", ", output_size[1], ",", output_size[2], "}"); - if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1 && !input.is_xpu()) { + if (output_size[0] == 1 && output_size[1] == 1 && output_size[2] == 1) { // in this case, adaptive pooling is just computing mean over hw // dimensions, which can be done more efficiently Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true); diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp index 1c037c31c6f64d..c0f2399138cea1 100644 --- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp @@ -297,7 +297,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu) int64_t osizeW = output_size[2]; if (input.ndimension() == 4) { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); @@ -320,7 +320,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_out_cpu) istrideW); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_cpu", [&] { auto input_data = input.const_data_ptr(); auto output_data = output.data_ptr(); @@ -390,7 +390,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu) /* backprop */ if (input.ndimension() == 4) { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); @@ -410,7 +410,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_cpu) osizeW); }); } else { - AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "adaptive_max_pool3d_backward", [&] { /* get raw pointers */ scalar_t* gradInput_data = gradInput.data_ptr(); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 5c3ce52df87d44..d61c9870f4c522 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -1140,87 +1140,103 @@ REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_SVE256_DISPATCH(cholesky_stub, &cholesky_kernel); REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_SVE256_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_SVE256_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_SVE256_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_SVE256_DISPATCH(geqrf_stub, &geqrf_kernel); REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_SVE256_DISPATCH(orgqr_stub, &orgqr_kernel_impl); REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_SVE256_DISPATCH(ormqr_stub, &ormqr_kernel); REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_SVE256_DISPATCH(lstsq_stub, &lstsq_kernel); REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_SVE256_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel); REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel); +REGISTER_SVE256_DISPATCH(lu_factor_stub, &lu_factor_kernel); REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel); REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); +REGISTER_SVE256_DISPATCH(ldl_factor_stub, &ldl_factor_kernel); REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel); REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); +REGISTER_SVE256_DISPATCH(ldl_solve_stub, &ldl_solve_kernel); + REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel); +REGISTER_SVE256_DISPATCH(lu_solve_stub, &lu_solve_kernel); REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel); REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel); REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel); REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel); REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel); +REGISTER_SVE256_DISPATCH(svd_stub, &svd_kernel); REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel); REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); +REGISTER_SVE256_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 2e0dadfbabca65..0a2789f4caff8b 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -41,6 +41,17 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * #include #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +#include +#endif // oneDNN + +#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5) + +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) +#include +#include +#endif // oneDNN BRGEMM + namespace at::native::cpublas { namespace internal { @@ -822,4 +833,350 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex +struct UnsafeUkernelKeyHasher { + std::size_t operator()(const key_t& key) const; +}; + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const BrgemmKey& key) const { + // Use beta, M, N, and K to compute hash to reduce the overhead as + // batch size, alpha, and data types are unlikely to change within the same kernel and + // leading dimensions are likely to be related to M, K, N or use fixed values. + std::size_t h = std::hash()(key.beta + 1); + h = std::hash()(key.M) ^ (h << 1); + h = std::hash()(key.N) ^ (h << 1); + h = std::hash()(key.K) ^ (h << 1); + h = std::hash()(key.ldc) ^ (h << 1); + return h; +} + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const PackKey& key) const { + // Use K and N to compute hash to reduce the overhead as + // data types are unlikely to change and + // ld_in/ld_out is likely to be related to K, N or use fixed values + std::size_t h = std::hash()(key.K); + h = std::hash()(key.N) ^ (h << 1); + return h; +} + +template +struct KernelCache { + using kstore_t = std::unordered_map, UnsafeUkernelKeyHasher>; + static inline std::shared_ptr&& fetch_or_create( + const key_t& key, + const std::function()>& callback) { + auto&& search = get_store().find(key); + if (search != get_store().end()) { + return std::move(search->second); + } else { + get_store().insert({key, callback()}); + return std::move(get_store()[key]); + } + } + + static inline kstore_t& get_store() { + static thread_local kstore_t cache_kernels; + return cache_kernels; + } +}; + +// Helper struct for convenient brgemm configuration +struct GemmHelper { + GemmHelper( + int64_t M, + int64_t N, + int64_t K, + int64_t bs, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + ScalarType dt_a, + ScalarType dt_b, + ScalarType dt_c, + const float alpha, + const float beta) { + // Create brgemm + brg = dnnl::ukernel::brgemm( + M, + N, + K, + bs, + ld_a, + ld_b, + ld_c, + get_dnnl_dtype(dt_a), + get_dnnl_dtype(dt_b), + get_dnnl_dtype(dt_c), + alpha, + beta); + // Create a scratchpad buffer for the brgemm execution + scratchpad = std::vector(brg.get_scratchpad_size()); + // Prepare default vector of pairs of tensors A and B offsets for each batch. + A_B_offsets.reserve(1); + A_B_offsets[0] = std::make_pair(0, 0); + } + dnnl::ukernel::brgemm brg; + std::vector scratchpad; + std::vector> A_B_offsets; +}; + +struct Brgemm : public KernelCache { + // Fetch/create GemmHelper object and execute brgemm with batch size = 1 + template + static inline void call( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const scalar_t_a* A, + const scalar_t_b* B, + scalar_t_c* C) { + auto&& key = BrgemmKey( + M, + N, + K, + int64_t(1), + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + // Fetch/create GemmHelper object + auto&& value = fetch_or_create(key, [&]() { + auto&& v = std::make_shared( + M, + N, + K, + 1, + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + (*v).brg.generate(); + return std::move(v); + }); + if (get_current() != value) { + dnnl::ukernel::brgemm::release_hw_context(); + ((*value).brg).set_hw_context(); + get_current() = value; + } + ((*value).brg) + .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data()); + } + + static inline std::shared_ptr& get_current() { + static thread_local std::shared_ptr current; + return current; + } + + static inline bool device_check(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16; + return fp16_support; + } + return false; + } +}; + +using pack_t = dnnl::ukernel::brgemm_pack_B; +struct Pack : public KernelCache { + static inline void call( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { + auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out); + auto&& pack = fetch_or_create(key, [&]() { + auto&& p = std::make_shared( + K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out)); + if (need_pack(dt_in)) { + (*p).generate(); + } + return std::move(p); + }); + if (need_pack(dt_in)) { + (*pack).execute(in, out); + } else { + TORCH_CHECK(false, "No need to pack"); + } + } + + static inline bool need_pack(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16; + return fp16_pack; + } + return false; + } +}; +#endif + +void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + if (Brgemm::device_check(ScalarType::Half)) { + Brgemm::call( + M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C); + return; + } +#endif + TORCH_CHECK(false, + "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported"); +} + +void brgemm_release() { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + dnnl::ukernel::brgemm::release_hw_context(); +#endif +} + +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out); +#else + TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled"); +#endif +} + +bool need_pack(ScalarType dt_in) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + return Pack::need_pack(dt_in); +#else + return false; +#endif +} + +} // namespace at::native::cpublas diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 3b30df1c21fad9..ad209329c95ecc 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -7,6 +7,7 @@ #include #include + namespace at::native::cpublas { namespace internal { @@ -186,4 +187,40 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); -} // namespace at::native::cpublas +// Batch-reduce GEMM +// Operates by the following formula: +// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size +// A Base pointer to a tensor A. +// B Base pointer to a tensor B. +// C Pointer to a tensor C (accumulation buffer). +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C); + +// Release brgemm hardware context +void brgemm_release(); + +// Pack B matrix to get better performance if needed +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out); + +// Whether pack is needed in the platform. +bool need_pack(ScalarType dt_in); + +} // namespace at::native::cpublas diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp index c1dc1f3a5eec6b..78222317a889ac 100644 --- a/aten/src/ATen/native/CPUFallback.cpp +++ b/aten/src/ATen/native/CPUFallback.cpp @@ -87,7 +87,11 @@ static bool validate_tensor_list(const c10::List& tensorlist) { return flag; } -void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views) { +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views, + c10::DispatchKey cpu_dispatch_key) { + TORCH_CHECK(c10::BackendComponent::CPUBit == c10::toBackendComponent(cpu_dispatch_key), + "Expected CPU backend DispatchKey but got ", + c10::toString(cpu_dispatch_key)); auto& schema_args = op.schema().arguments(); const auto num_arguments = schema_args.size(); auto arguments = torch::jit::last(stack, num_arguments); @@ -143,7 +147,7 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool } // Step 2: Call the underlying CPU implementation of the operator - op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack); + op.redispatchBoxed(c10::DispatchKeySet(cpu_dispatch_key), stack); // Step 3: We need to take special care to handle mutable aliases properly: // If any input tensors are mutable aliases, we need to diff --git a/aten/src/ATen/native/CPUFallback.h b/aten/src/ATen/native/CPUFallback.h index 606901fe1926fb..44cb534b8db2c9 100644 --- a/aten/src/ATen/native/CPUFallback.h +++ b/aten/src/ATen/native/CPUFallback.h @@ -11,7 +11,8 @@ namespace at::native { // This function implements a boxed fallback to CPU. // External backends can add their own custom logging on top if it to customize their own CPU fallbacks. -TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false); +TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool error_on_views = false, + c10::DispatchKey cpu_dispatch_key = c10::DispatchKey::CPU); // This is a helper function that backends can use to directly call their boxed CPU fallback // TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands. diff --git a/aten/src/ATen/native/Col2Im.cpp b/aten/src/ATen/native/Col2Im.cpp index 0a98ee5c46ca77..51e005c2901b93 100644 --- a/aten/src/ATen/native/Col2Im.cpp +++ b/aten/src/ATen/native/Col2Im.cpp @@ -144,7 +144,7 @@ static void col2im_out_cpu_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "col2im_out_cpu", [&] { Tensor input_n = Tensor(); Tensor output_n = Tensor(); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 97a51b19377003..2f85c9724372a8 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -421,12 +421,18 @@ struct ConvParams { // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) #if !defined(C10_MOBILE) - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } if (!detail::getCUDAHooks().compiledWithCuDNN()) { return false; } + if (needs_64bit_indexing_no_split(input, weight)) { + static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) { + TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions" + " if the V8 API is not enabled or before cuDNN version 9.3+." + " Consider upgrading cuDNN and/or enabling the V8 API for better efficiency."); + return false; + } + } if (!input.is_cuda() || !cudnn_enabled) { return false; } diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 699f2e0fe2e27c..b57649c2632592 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -34,6 +34,17 @@ static CPUCapability compute_cpu_capability() { if (strcmp(envar, "zvector") == 0) { return CPUCapability::ZVECTOR; } +#elif defined(HAVE_SVE_CPU_DEFINITION) + int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW. +#ifdef HAVE_SVE256_CPU_DEFINITION + if (strcmp(envar, "sve256") == 0) { + if (sve_vl == 256) { + return CPUCapability::SVE256; + } + TORCH_WARN("SVE256 capability not available on hardware. Falling back to DEFAULT"); + return CPUCapability::DEFAULT; + } +#endif #else #ifdef HAVE_AVX512_CPU_DEFINITION if (strcmp(envar, "avx512") == 0) { @@ -52,7 +63,7 @@ static CPUCapability compute_cpu_capability() { TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar); } -#if !defined(__powerpc__) && !defined(__s390x__) +#if !defined(__powerpc__) && !defined(__s390x__) && !defined(HAVE_SVE_CPU_DEFINITION) if (cpuinfo_initialize()) { #if defined(HAVE_AVX512_CPU_DEFINITION) // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in @@ -79,6 +90,23 @@ static CPUCapability compute_cpu_capability() { } #endif +#if defined(__linux__) && defined(HAVE_SVE_CPU_DEFINITION) + if (cpuinfo_initialize() && cpuinfo_has_arm_sve()) { + int sve_vl = cpuinfo_get_max_arm_sve_length(); //Returns maximum SVE VL supported by your HW. + if (sve_vl <= 0) { + // SVE is not supported on this system. + // Return the default CPU capability. + return CPUCapability::DEFAULT; + } + #ifdef HAVE_SVE256_CPU_DEFINITION + if (sve_vl == 256) { // Check for SVE256 + return CPUCapability::SVE256; + } + #endif + // Return the default CPU capability. + return CPUCapability::DEFAULT; + } +#endif #ifdef HAVE_VSX_CPU_DEFINITION return CPUCapability::VSX; #else @@ -106,6 +134,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { constexpr auto supported_devices = c10::array_of( c10::DeviceType::CPU, @@ -139,6 +170,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , SVE256 #endif ); if (!std::holds_alternative(result)) { @@ -191,6 +225,9 @@ void* DispatchStubImpl::get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { auto result = try_get_call_ptr( @@ -211,6 +248,10 @@ void* DispatchStubImpl::get_call_ptr( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , + SVE256 #endif ); if (std::holds_alternative(result)) { @@ -242,6 +283,9 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ){ @@ -274,6 +318,16 @@ DispatchResult DispatchStubImpl::try_choose_cpu_impl( if (capability >= static_cast(CPUCapability::ZVECTOR)) { return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel; } +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE256)) { + if (C10_UNLIKELY(!SVE256)) { + // dispatch to DEFAULT, since the SVE kernel is missing + return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; + } else { + return DispatchResult(SVE256); + } + } #endif return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; } @@ -292,6 +346,9 @@ void* DispatchStubImpl::choose_cpu_impl( #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 +#endif ) { auto capability = static_cast(get_cpu_capability()); (void)capability; @@ -326,6 +383,17 @@ void* DispatchStubImpl::choose_cpu_impl( TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel"); return ZVECTOR; } +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::SVE256)) { + if (C10_UNLIKELY(!SVE256)) { + // dispatch to DEFAULT, since the SVE kernel is missing + TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); + return DEFAULT; + } else { + return SVE256; + } + } #endif TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 641b779701d1e2..22a97ca9882b8e 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -64,6 +64,8 @@ enum class CPUCapability { VSX = 1, #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) ZVECTOR = 1, +#elif defined(HAVE_SVE_CPU_DEFINITION) + SVE256 = 1, #else AVX2 = 1, AVX512 = 2, @@ -112,6 +114,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -130,6 +135,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -148,6 +156,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -169,6 +180,9 @@ struct TORCH_API DispatchStubImpl { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , void *SVE256 #endif ); @@ -221,6 +235,9 @@ struct DispatchStub { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) #endif ) ); @@ -275,6 +292,9 @@ struct DispatchStub { #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) +#endif +#ifdef HAVE_SVE256_CPU_DEFINITION + , reinterpret_cast(SVE256) #endif ); if (std::holds_alternative(result)){ @@ -296,6 +316,9 @@ struct DispatchStub { #ifdef HAVE_ZVECTOR_CPU_DEFINITION static TORCH_API FnPtr ZVECTOR; #endif +#ifdef HAVE_SVE256_CPU_DEFINITION + static TORCH_API FnPtr SVE256; +#endif private: DispatchStubImpl impl; }; @@ -387,6 +410,12 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_ZVECTOR_DISPATCH(name, fn) #endif +#ifdef HAVE_SVE256_CPU_DEFINITION +#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn) +#else +#define REGISTER_SVE256_DISPATCH(name, fn) +#endif + // Macro to register the same kernel for all CPU arch types. This is useful // if a kernel does not benefit from being recompiled across different arch types. #define REGISTER_ALL_CPU_DISPATCH(name, fn) \ @@ -394,7 +423,8 @@ struct RegisterPRIVATEUSE1Dispatch { REGISTER_AVX512_DISPATCH(name, fn) \ REGISTER_AVX2_DISPATCH(name, fn) \ REGISTER_VSX_DISPATCH(name, fn) \ - REGISTER_ZVECTOR_DISPATCH(name, fn) + REGISTER_ZVECTOR_DISPATCH(name, fn) \ + REGISTER_SVE256_DISPATCH(name, fn) #define REGISTER_NO_CPU_DISPATCH(name) \ REGISTER_ALL_CPU_DISPATCH(name, nullptr) @@ -432,12 +462,14 @@ struct RegisterPRIVATEUSE1Dispatch { #elif defined(CPU_CAPABILITY) // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others. +// ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others. #ifdef CPU_CAPABILITY_AVX512 #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr)) #else #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #endif } // namespace at::native diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 0de9632ddb1756..bd1d974e9bee50 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -585,19 +586,15 @@ Tensor& multinomial_out(const Tensor& self, // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 if (!with_replacement || n_sample == 1) { // Sanity checks on `self`. - auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item(); - TORCH_CHECK( - is_valid.to(), - "probability tensor contains either `inf`, `nan` or element < 0"); - bool zero_prob_condition = false; + auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)); + at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0"); + at::Tensor zero_prob_condition; if (self.dim() == 1){ - zero_prob_condition = (self.sum() == 0).item().to(); + zero_prob_condition = (self.sum() == 0); } else { - zero_prob_condition = (self.sum(1) == 0).sum().item().to(); + zero_prob_condition = (self.sum(1) == 0).any(); } - TORCH_CHECK( - !zero_prob_condition, - "invalid multinomial distribution (sum of probabilities <= 0)"); + at::_assert_async(~zero_prob_condition, "invalid multinomial distribution (sum of probabilities <= 0)"); // The algorithm is from gumbel softmax. // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1) diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index f5c0672402f361..a8fbe13b8da0b9 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -128,10 +128,26 @@ inline bool _check_tensors_share_device_and_dtype( // corresponding tensors in tensor lists have the same sizes and strides. inline bool _check_tensors_share_sizes_and_strides( ArrayRef tensorLists) { + auto is_diff_stride = [](const IntArrayRef& size, + const IntArrayRef& left_stride, + const IntArrayRef& right_stride) -> bool { + const size_t size_size = size.size(); + for (const auto dim : c10::irange(size_size)) { + if (size[dim] == 1) + continue; + if (left_stride[dim] != right_stride[dim]) { + return true; + } + } + return false; + }; for (const auto i : c10::irange(1, tensorLists.size())) { for (const auto j : c10::irange(tensorLists[0].size())) { if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() || - tensorLists[0][j].strides() != tensorLists[i][j].strides()) { + is_diff_stride( + tensorLists[0][j].sizes(), + tensorLists[0][j].strides(), + tensorLists[i][j].strides())) { return false; } } diff --git a/aten/src/ATen/native/Im2Col.cpp b/aten/src/ATen/native/Im2Col.cpp index dac2ee6e3f103a..25eb4d6787240a 100644 --- a/aten/src/ATen/native/Im2Col.cpp +++ b/aten/src/ATen/native/Im2Col.cpp @@ -94,7 +94,7 @@ static void im2col_out_cpu_template( output.resize_({batch_size, n_output_plane, output_length}); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kBFloat16, kHalf, kBool, input.scalar_type(), "im2col_out_cpu", [&] { Tensor input_n; Tensor output_n; diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 0611193357112e..2c4e000fce4421 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -1358,13 +1359,8 @@ static inline int64_t get_mkldnn_matmul_min_dim() { static auto value = [&] { const int64_t default_min_dim = [&] { // Minimum dimension requirement for MKLDNN; derived based on experiments. - // By default, it's only enabled on Neoverse V1. -#if !defined(__s390x__) && !defined(__powerpc__) - if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) { - return 8; - } -#endif - return 0; + //it's enabled on all Neoverse cpus. + return is_arm_neoverse() ? 8 : 0; }(); const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_DIM"); return ptr != nullptr ? std::atoi(ptr) : default_min_dim; @@ -1377,13 +1373,8 @@ static inline int64_t get_mkldnn_matmul_min_size() { static auto value = [&] { const int64_t default_min_size = [&] { // Minimum size requirement for MKLDNN; derived based on experiments. - // By default, it's only enabled on Neoverse V1. -#if !defined(__s390x__) && !defined(__powerpc__) - if (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1) { - return 8 * 1024; - } -#endif - return 0; + // it's enabled on all Neoverse cpus. + return is_arm_neoverse() ? 8 * 1024 : 0; }(); const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_SIZE"); return ptr != nullptr ? std::atoi(ptr) : default_min_size; diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 1fb64c32370509..796eac362b1246 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -209,7 +209,13 @@ std::tuple batch_norm_cpu_update_stats_template( bool all_contiguous = is_contiguous(input); constexpr bool mixed_type = !std::is_same_v; - const auto dtype = mixed_type ? kFloat : input.scalar_type(); + // Using float data type for Half _var_sum in batchnorm stats updating on CPU + // to avoid _var_sum overflow since the representation range of Half is small. + using opmath_t = std::conditional_t, at::opmath_type, param_t>; + auto dtype = mixed_type ? kFloat : input.scalar_type(); + if (dtype == kHalf) { + dtype = kFloat; + } auto save_mean_a = save_mean.accessor(); auto save_var_transform_a = save_var_transform.accessor(); @@ -220,9 +226,9 @@ std::tuple batch_norm_cpu_update_stats_template( if (all_contiguous) { auto _mean = at::empty({n_input}, input.options().dtype(dtype)); auto _var_sum = at::empty({n_input}, input.options().dtype(dtype)); - auto _mean_a = _mean.accessor(); - auto _var_sum_a = _var_sum.accessor(); - auto momentum_ = static_cast(momentum); + auto _mean_a = _mean.accessor(); + auto _var_sum_a = _var_sum.accessor(); + auto momentum_ = static_cast(momentum); batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input); @@ -552,7 +558,7 @@ std::tuple _batch_norm_impl_index( if (input.sym_numel() == 0) { Tensor reserve = at::empty({0}, input.options().dtype(kByte)); auto options = input.options().dtype( - at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda())); + at::toAccumulateType(input.scalar_type(), input.device().type())); auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options); auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options); diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index d1bb95f833231f..9db7b4cb7da090 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -15,6 +15,9 @@ #include #include #include +#if AT_MKLDNN_ENABLED() +#include +#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -97,7 +100,10 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) { }; return input.options().backend() == at::Backend::CPU && is_cpu_backend(params) && is_cpu_backend(hx) && - (input.scalar_type() == kFloat || input.scalar_type() == kBFloat16) && + (input.scalar_type() == kFloat || + (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) || + (input.scalar_type() == kHalf && !at::GradMode::is_enabled() && + mkldnn_fp16_device_check())) && input.numel() != 0; #endif return false; diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index c73ed0cce0562b..bbc50c3c2fca75 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -1414,6 +1414,15 @@ Tensor& mean_out(const Tensor& self, DimnameList dim, return at::mean_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype); } +Tensor& mean_dtype_out(const Tensor &self, std::optional dtype, Tensor& result) { + TORCH_CHECK( + canCast(self.scalar_type(), result.scalar_type()), + "mean.dtype_out(): input types can't be cast to the desired output type ", + result.scalar_type()); + // at::mean_out should make sure dtype and result.scalar_type() are the same + return at::mean_out(result, self, IntArrayRef{}, false, dtype); +} + // TODO(@heitorschueroff) implement custom kernels for nanmean Tensor& nanmean_out( const Tensor& self, @@ -1447,7 +1456,9 @@ Tensor nanmean( static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) { // can't take max of empty tensor if (self.numel() != 0) { - auto maxes = at::amax(self, dims, true); + // For complex numbers, use the real part to calculate the max. Based on + // https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers + auto maxes = at::amax(at::real(self), dims, true); auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims)); maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0); at::sum_out(result, (self - maxes).exp_(), dims, keepdim); @@ -1460,7 +1471,8 @@ static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRe } Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) { - TORCH_CHECK(at::isFloatingType(result.scalar_type()), + // Complex type implies floating point type + TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()), "logsumexp(): Expected floating point type for result tensor, but got: ", result.scalar_type()); { diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index f4d895c2c58584..f64b9cc3cefa34 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -282,9 +282,9 @@ void resize_bytes_nocuda(const Storage& storage, const c10::SymInt& newsize) { } else if (device_type == at::kMeta) { at::native::resize_bytes_meta(storage.unsafeGetStorageImpl(), newsize); } else if (device_type == at::kPrivateUse1) { - at::GetPrivateUse1HooksInterface()->resizePrivateUse1Bytes( + at::detail::getPrivateUse1Hooks().resizePrivateUse1Bytes( storage, newsize.expect_int()); - } else if (device_type == at::kXPU || device_type == at::kHPU) { + } else if (device_type == at::kXPU || device_type == at::kHPU || device_type == at::kMTIA) { ptrdiff_t size_bytes_i = newsize.expect_int(); TORCH_CHECK( !c10::overflows(size_bytes_i), diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 0ab01bbe8c0bd2..08220869701407 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -466,6 +466,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cp REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); +REGISTER_SVE256_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel); // offsets dispatches REGISTER_ARCH_DISPATCH( @@ -476,6 +477,7 @@ REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cp REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); +REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel); // Currently some computation is being duplicated across forward and backward. // TODO: Cache indices in forward pass to re-use in backward @@ -546,6 +548,9 @@ REGISTER_VSX_DISPATCH( REGISTER_ZVECTOR_DISPATCH( _segment_reduce_lengths_backward_stub, &_segment_reduce_cpu_lengths_backward_kernel); +REGISTER_SVE256_DISPATCH( + _segment_reduce_lengths_backward_stub, + &_segment_reduce_cpu_lengths_backward_kernel); REGISTER_ARCH_DISPATCH( _segment_reduce_offsets_backward_stub, @@ -563,5 +568,8 @@ REGISTER_VSX_DISPATCH( REGISTER_ZVECTOR_DISPATCH( _segment_reduce_offsets_backward_stub, &_segment_reduce_cpu_offsets_backward_kernel); +REGISTER_SVE256_DISPATCH( + _segment_reduce_offsets_backward_stub, + &_segment_reduce_cpu_offsets_backward_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 105636c0a6b746..33d5d19d8b888a 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -590,9 +590,9 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) } } - // For CUDA/MPS tensors, force all index tensors to have the same striding to - // simplify the CUDA/MPS kernel. - if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS)) { + // For CUDA/MPS/XPU tensors, force all index tensors to have the same striding to + // simplify the CUDA/MPS/XPU kernel. + if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS || this->src.device().type() == kXPU)) { if (!all_strides_match(indices)) { for (auto & indice : indices) { indice = indice.contiguous(); @@ -1588,7 +1588,7 @@ static bool can_use_expanded_index_path( } const auto st = self.scalar_type(); - if (!(c10::isFloatingType(st)) || st == ScalarType::Half) { + if (!(c10::isFloatingType(st))) { return false; } @@ -1808,7 +1808,7 @@ void scatter_impl( if (index.numel() == 0) return; auto op = ReductionType::SUM; - bool deterministic = globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA; + bool deterministic = globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU); if (reduce.has_value()) { op = get_operator_enum(reduce.value(), use_new_options); @@ -2284,12 +2284,20 @@ int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) { } Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){ - return (self != 0).sum(dims); + auto reduce = self; + if (reduce.scalar_type() != kBool) { + reduce = reduce != 0; + } + return reduce.sum(dims); } Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){ if (!dims.empty()) { - return (self != 0).sum(dims); + auto reduce = self; + if (reduce.scalar_type() != kBool) { + reduce = reduce != 0; + } + return reduce.sum(dims); } // Optimized all-reduce diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 6d6db1477f1f8f..c82e429621812f 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -82,7 +82,6 @@ namespace at::meta { static inline void check_for_unsupported_isin_dtype(const ScalarType type) { // Bail out for dtypes unsupported by the sorting algorithm to keep the interface consistent. TORCH_CHECK(type != ScalarType::Bool && - type != ScalarType::BFloat16 && type != ScalarType::ComplexFloat && type != ScalarType::ComplexDouble, "Unsupported input type encountered for isin(): ", type); diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index dc0c1054d16b05..22a576408bfbbb 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -772,9 +772,6 @@ inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_str } Tensor view_dtype(const Tensor& self, ScalarType dtype) { - if (self.scalar_type() == dtype) { - return self; - } const auto type_meta = c10::scalarTypeToTypeMeta(dtype); TORCH_CHECK(!self.is_conj(), "torch.Tensor.view is not supported for conjugate view tensors when converting to a different dtype."); diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp index aa2bab7c6b9455..bdc3a554b85b7e 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -11,18 +11,18 @@ namespace ao { namespace sparse { namespace { -constexpr int64_t serialization_version_index = 0; -constexpr int64_t bias_index = 1; -constexpr int64_t out_features_block_size_index = 2; -constexpr int64_t in_features_block_size_index = 3; -constexpr int64_t weight_scales_index = 4; -constexpr int64_t weight_zero_point_index = 5; -constexpr int64_t quantization_scheme_index = 6; -constexpr int64_t row_block_indices_index = 7; -constexpr int64_t col_block_indices_index = 8; -constexpr int64_t weight_values_index = 9; -constexpr int64_t num_output_channels_index = 10; -constexpr int64_t num_input_channels_index = 11; +constexpr int64_t serialization_version_index [[maybe_unused]] = 0; +constexpr int64_t bias_index [[maybe_unused]] = 1; +constexpr int64_t out_features_block_size_index [[maybe_unused]] = 2; +constexpr int64_t in_features_block_size_index [[maybe_unused]] = 3; +constexpr int64_t weight_scales_index [[maybe_unused]] = 4; +constexpr int64_t weight_zero_point_index [[maybe_unused]] = 5; +constexpr int64_t quantization_scheme_index [[maybe_unused]] = 6; +constexpr int64_t row_block_indices_index [[maybe_unused]] = 7; +constexpr int64_t col_block_indices_index [[maybe_unused]] = 8; +constexpr int64_t weight_values_index [[maybe_unused]] = 9; +constexpr int64_t num_output_channels_index [[maybe_unused]] = 10; +constexpr int64_t num_input_channels_index [[maybe_unused]] = 11; template std::vector unwrap_vector(at::Tensor tensor) { diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 51417d0e6851eb..838b7fbd097fed 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -81,6 +81,12 @@ void atan2_kernel(TensorIteratorBase& iter) { } #if !defined(C10_MOBILE) +#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \ + AT_DISPATCH_V2( \ + TYPE, \ + NAME, \ + AT_WRAP(__VA_ARGS__), \ + AT_EXPAND(AT_INTEGRAL_TYPES_V2)) #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \ AT_DISPATCH_V2( \ TYPE, \ @@ -104,6 +110,8 @@ void atan2_kernel(TensorIteratorBase& iter) { AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else +#define _AT_DISPATCH_INTEGRAL_TYPES_V2(TYPE, NAME, ...) \ + AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, __VA_ARGS__) #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ kComplexHalf, kHalf, kBool, kBFloat16, TYPE, NAME, __VA_ARGS__) @@ -382,7 +390,7 @@ void bitwise_and_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [](bool a, bool b) { return a && b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_and_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_and_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a & b; }, @@ -395,7 +403,7 @@ void bitwise_or_kernel(TensorIteratorBase& iter) { if (iter.dtype() == ScalarType::Bool) { cpu_kernel(iter, [](bool a, bool b) { return a || b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_or_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_or_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a | b; }, @@ -410,7 +418,7 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) { // this operation for both Boolean and integral types. cpu_kernel(iter, [](bool a, bool b) { return a != b; }); } else { - AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_xor_cpu", [&]() { + _AT_DISPATCH_INTEGRAL_TYPES_V2(iter.dtype(), "bitwise_xor_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return a ^ b; }, diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index 601b098d37c0ab..90e34fc70b81f4 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -16,7 +16,6 @@ #else #include #endif - namespace at::native { namespace { @@ -202,7 +201,97 @@ void reshape_attn_mask_to_4d( .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); } -template +template +inline void copy_value_with_pad( + const scalar_t* value_ptr, + scalar_t* dst_ptr, + int64_t rows, + int64_t cols, + int64_t prows, + int64_t pcols, + int64_t ldi) { + auto vec_size = at::vec::Vectorized::size(); + int64_t i = 0; + for (; i < rows; i++) { + int64_t j = 0; + for (; j < cols - (cols % vec_size); j += vec_size) { + auto vec_v = + at::vec::Vectorized::loadu(value_ptr + i * ldi + j); + vec_v.store(dst_ptr + i * pcols + j); + } + + if (j < cols) { + auto vec_v = at::vec::Vectorized::loadu( + value_ptr + i * ldi + j, cols - j); + vec_v.store(dst_ptr + i * pcols + j, cols - j); + } + + // col padding + auto psize = pcols - cols; + if (psize > 0) { + auto zero_vec = at::vec::Vectorized(0); + int64_t pj = 0; + for (; pj < psize - (psize % vec_size); pj += vec_size) { + zero_vec.store(dst_ptr + i * pcols + cols + pj); + } + if (pj < psize) { + zero_vec.store(dst_ptr + i * pcols + cols + pj, psize - pj); + } + } + } + // row padding + for (; i < prows; i++) { + auto zero_vec = at::vec::Vectorized(0); + int64_t j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero_vec.store(dst_ptr + i * pcols + j); + } + if (j < pcols) { + zero_vec.store(dst_ptr + i * pcols + j, pcols - j); + } + + } +} + +template +inline void pad_remain_row_col_zero( + scalar_t* value_ptr, + int rows, + int cols, + int prows, + int pcols, + int ldi) { + auto psize = pcols - cols; + if (psize == 0 && prows == rows) { + return; + } + auto vec_size = at::vec::Vectorized::size(); + auto zero = at::vec::Vectorized(0); + if (psize > 0) { + for (int i = 0; i < rows; i++) { + int j = 0; + for (; j < psize - (psize % vec_size); j += vec_size) { + zero.store(value_ptr + i * ldi + cols + j); + } + if (j < psize) { + zero.store(value_ptr + i * ldi + cols + j, psize - j); + } + } + } + + for (int i = rows; i < prows; i++) { + int j = 0; + for (; j < pcols - (pcols % vec_size); j += vec_size) { + zero.store(value_ptr + i * ldi + j); + } + if (j < pcols) { + zero.store(value_ptr + i * ldi + j, pcols - j); + } + } + +} + +template void cpu_flash_attention( const Tensor& output, const Tensor& logsumexp, @@ -278,21 +367,70 @@ void cpu_flash_attention( int64_t qSplitSize = q_split_size > qSize ? qSize : q_split_size; int64_t kvSplitSize = kv_split_size > kvSize ? kvSize : kv_split_size; - int64_t qSlice = (qSize - 1) / qSplitSize + 1; + int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize; + int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize; + int64_t kvTail = (kvSize - 1) % kvSplitSize + 1; int64_t num_thread = at::get_num_threads(); const auto dtype = query.scalar_type(); const auto accumulate_dtype = toOpMathType(dtype); + // Whether pack is needed + bool need_pack = false; + // Block size of packing B matrix + int64_t packb_size = 64; + // Use packb_size due to the limitation: + // oneDNN pack only supports output leading dimention being one of (16, 32, 48, 64) + // For instance, + // for q @ k.T [qSplitSize, headSize] * [headSize, kvSplitSize] = [qSplitSize, kvSplitSize], + // we need to split kvSplitSize with packb_size for packing k.T, + // for (q @ k.T) @ v [qSplitSize, kvSplitSize] x [kvSplitSize, headSize] -> [qSplitSize, headSize], + // we need to split headSize with packb_size for packing v + // TODO Simplify the check when oneDNN supports fused pack with transpose and has better performance + if (with_pack) { + need_pack = num_head >= 4 && headSize % packb_size == 0 && kvSize >= packb_size; + if (need_pack) { + float pack_size = batchSize * num_head * kvSize * headSize / 1024; + float gemm_size_per_thread = + (batchSize * num_head * qSlice + num_thread - 1) / num_thread * + qSplitSize * (is_causal ? qSize : kvSize) * headSize / 1024; + float gsize = gemm_size_per_thread / pack_size; + // When the number of gemm is much greater than the number of pack, + // the pack and padding overhead can be overlaped. + if (pack_size < 2688) { + need_pack = gsize >= 36 || (gsize >= 24 && headSize > packb_size); + } else if (pack_size < 16384) { + need_pack = gsize >= (is_causal ? 54 : 52); + } else { + need_pack = gsize >= (is_causal ? 54 : 40); + } + } + } + + int64_t rHeadSize = need_pack ? (headSize + packb_size - 1) / packb_size * packb_size : headSize; + int64_t rkvSplitSize = need_pack ? (kvSplitSize + packb_size - 1) / packb_size * packb_size : kvSplitSize; + int64_t rkvTail = need_pack ? (kvTail + packb_size - 1) / packb_size * packb_size : kvTail; + int64_t rkvSize = kv_split_size > kvSize ? rkvTail : rkvSplitSize * kvSlice + rkvTail; + + // oneDNN pack does not support odd K now, we need also pad odd K + bool headSize_even = headSize % 2 == 0; + int64_t eheadSize = need_pack && !headSize_even ? headSize + 1: headSize; + int64_t ekvSplitSize = need_pack && (kvSplitSize % 2 != 0) ? kvSplitSize + 1 : kvSplitSize; + int64_t ekvTail = need_pack && (kvTail % 2 != 0) ? kvTail + 1 : kvTail; + // allocate per thread temp buf (accumulate type) int64_t size_per_thread = - /* qk */ qSplitSize * kvSplitSize + + /* qk */ qSplitSize * rkvSplitSize + /* qk_max */ qSplitSize + /* qk_sum */ qSplitSize + - /* dst */ qSplitSize * headSize; + /* dst */ qSplitSize * rHeadSize; at::Tensor buf = at::empty({num_thread, size_per_thread}, query.options().dtype(accumulate_dtype)); - at::Tensor buf_reduced = at::empty({num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, query.options()); + at::Tensor buf_reduced = at::empty( + {num_thread, + qSplitSize, + is_reduced_type ? ekvSplitSize : 0}, + query.options()); // Data ptrs const scalar_t* q_data = query.const_data_ptr(); @@ -306,19 +444,129 @@ void cpu_flash_attention( accum_t* buf_data = buf.data_ptr(); scalar_t* buf_reduced_data = is_reduced_type ? buf_reduced.data_ptr() : nullptr; + // Buffer to store padding query + scalar_t* query_padding_ptr = nullptr; + std::unique_ptr query_padding_data; + if (!headSize_even && need_pack) { + query_padding_data = std::make_unique(num_thread * qSplitSize * eheadSize); + query_padding_ptr = query_padding_data.get(); + } + // Buffer to store Key and Value after transforms + scalar_t* key_reorder_ptr = nullptr; + std::unique_ptr key_reorder_data; + scalar_t* value_reorder_ptr = nullptr; + std::unique_ptr value_reorder_data; + int kv_padding_size = (kvSize - 1) / kvSplitSize * ekvSplitSize + ekvTail; + if (need_pack) { + key_reorder_data = std::make_unique(batchSize * num_head * eheadSize * rkvSize); + key_reorder_ptr = key_reorder_data.get(); + value_reorder_data = std::make_unique(batchSize * num_head * kv_padding_size * rHeadSize); + value_reorder_ptr = value_reorder_data.get(); + } + + // Reorder K, V + if (need_pack) { + at::parallel_for(0, batchSize * num_head * kvSlice, 1, [&](int64_t begin, int64_t end) { + int64_t i = 0, j = 0, l = 0, n = 0; + at::native::data_index_init(begin, i, batchSize, j, num_head, l, kvSlice); + std::unique_ptr transpose_buffer = std::make_unique(eheadSize * packb_size); + scalar_t* transpose_buffer_ptr = transpose_buffer.get(); + std::unique_ptr v_copy_buffer = std::make_unique(ekvSplitSize * packb_size); + scalar_t* v_copy_buffer_ptr = v_copy_buffer.get(); + for (C10_UNUSED auto z : c10::irange(begin, end)) { + n = l * kvSplitSize; + int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1; + + // Split kvSplitSize with packb_size + // [kvSplitSize, headSize] -> [div_up(kvSplitSize, packb_size), packb_size, headSize] + // Transpose [packb_size, headSize] -> [headSize, packb_size] + // Pack transposed buffer + + for (int64_t b = 0; b < kvBlockSize; b += packb_size) { + bool tail = kvBlockSize - b < packb_size; + // TODO Use fused pack with transpose support when oneDNN supports such usage + utils::transpose( + tail ? kvBlockSize - b : packb_size, + headSize, + /* src_ptr */ + reinterpret_cast( + k_data + i * kStrideB + j * kStrideH + n * kStrideN + + b * kStrideN), + /* ld_src */ kStrideN, + /* dst */ reinterpret_cast(transpose_buffer_ptr), + /* ld_dst */ packb_size); + // Pad [headSize, x] -> [eheadSize, x] + if (!headSize_even) { + pad_remain_row_col_zero( + transpose_buffer_ptr, + headSize, + packb_size, + eheadSize, + packb_size, + packb_size); + } + // Pack + cpublas::pack( + /* K */ eheadSize, + /* N */ packb_size, + /* ld_in */ packb_size, + /* ld_out */ packb_size, + /* dt_in */ dtype, + /* dt_out */ dtype, + transpose_buffer_ptr, + key_reorder_ptr + i * num_head * eheadSize * rkvSize + + j * eheadSize * rkvSize + n * eheadSize + b * eheadSize); + } + + // Split headSize with packb_size + // [kvSplitSize, headSize] -> [kvSplitSize, div_up(headSize, packb_size), packb_size] + for (int64_t b = 0; b < headSize; b += packb_size) { + // Do copy due to the limitation of input_ld of oneDNN pack: + // Regarding packing [K, N], only input_ld == N is supported + // TODO: remove the copy when pack supports input_ld >= N + copy_value_with_pad( + v_data + i * vStrideB + j * vStrideH + n * vStrideN + b, + v_copy_buffer_ptr, + kvBlockSize, + (headSize - b < packb_size) ? headSize - b : packb_size, + ekvBlockSize, + packb_size, + vStrideN); + cpublas::pack( + ekvBlockSize, + packb_size, + packb_size, + packb_size, + dtype, + dtype, + v_copy_buffer_ptr, + value_reorder_ptr + + i * num_head * kv_padding_size * rHeadSize + + j * kv_padding_size * rHeadSize + n * rHeadSize + + ekvBlockSize * b); + } + // Move to the next query + at::native::data_index_step(i, batchSize, j, num_head, l, kvSlice); + } + }); + } + at::parallel_for(0, batchSize * num_head * qSlice, 1, [&](int64_t begin, int64_t end) { int64_t i = 0, j = 0, k = 0; data_index_init(begin, i, batchSize, j, num_head, k, qSlice); int ompIdx = at::get_thread_num(); accum_t* buf_ptr = buf_data + ompIdx * size_per_thread; accum_t* qk_data = buf_ptr; - accum_t* qk_max_data = qk_data + qSplitSize * kvSplitSize; + accum_t* qk_max_data = qk_data + qSplitSize * rkvSplitSize; accum_t* qk_sum_data = qk_max_data + qSplitSize; accum_t* dst_data = qk_sum_data + qSplitSize; - scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * kvSplitSize : nullptr; + scalar_t* qk_reduced_data = is_reduced_type ? buf_reduced_data + ompIdx * qSplitSize * ekvSplitSize : nullptr; + scalar_t* query_t_padding_ptr = (!headSize_even && need_pack) + ? query_padding_ptr + ompIdx * qSplitSize * eheadSize + : nullptr; - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable + for (C10_UNUSED auto z : c10::irange(begin, end)) { int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize max and sum @@ -327,10 +575,46 @@ void cpu_flash_attention( fill_stub(qk_sum_data, static_cast(0), qBlockSize); int64_t num_keys = is_causal ? std::min(m + qBlockSize, kvSize) : kvSize; + if (!headSize_even && need_pack) { + // Pad query if headSize is not even + // [qBlockSize, headSize] -> [qBlockSize, eheadSize] + copy_value_with_pad( + q_data + i * qStrideB + j * qStrideH + m * qStrideM, + query_t_padding_ptr, + qBlockSize, + headSize, + qBlockSize, + eheadSize, + qStrideM + ); + } for (int64_t n = 0; n < num_keys; n += kvSplitSize) { int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); + int64_t ekvBlockSize = (need_pack && kvBlockSize % 2 != 0) ? kvBlockSize + 1 : kvBlockSize; + int64_t rkvBlockSize = kvBlockSize == kvSplitSize ? rkvSplitSize : rkvTail; // Calculate scale * q @ k.T - cpublas::gemm( + if (need_pack) { + if constexpr (std::is_same_v) { + for (int64_t b = 0; b < kvBlockSize; b += packb_size) { + cpublas::brgemm( + qBlockSize, + packb_size, + eheadSize, + headSize_even ? qStrideM : eheadSize, + packb_size, + rkvBlockSize, + 1.f, + 0.f, + !headSize_even + ? query_t_padding_ptr + : q_data + i * qStrideB + j * qStrideH + m * qStrideM, + key_reorder_ptr + i * num_head * eheadSize * rkvSize + + j * eheadSize * rkvSize + n * eheadSize + b * eheadSize, + qk_data + b); + } + } + } else { + cpublas::gemm( TransposeType::Transpose, TransposeType::NoTranspose, kvBlockSize, @@ -346,11 +630,12 @@ void cpu_flash_attention( static_cast(0), qk_data, kvBlockSize); + } // Apply causal mask, fill unused with -inf if (is_causal && num_keys - n <= kvSplitSize) { for (const auto row : c10::irange(qBlockSize)) { int64_t last_col = m + row - n; - accum_t* row_ptr = qk_data + row * kvBlockSize; + accum_t* row_ptr = qk_data + row * rkvBlockSize; fill_stub(row_ptr + last_col + 1, -std::numeric_limits::infinity(), kvBlockSize - last_col - 1); @@ -363,29 +648,29 @@ void cpu_flash_attention( for (int64_t row = 0; row < qBlockSize; ++row) { #if __GNUC__ == 11 && __GNUC_MINOR__ >= 4 && defined(__ARM_FEATURE_SVE) _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM + (mStrideN == 0 ? 0 : n), kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor, mStrideN == 0); #else if (mStrideN == 0) { _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor); } else { _scale_attn_mask_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, mask_data + i * mStrideB + j * mStrideH + (m + row) * mStrideM + n, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor); } #endif @@ -398,28 +683,28 @@ void cpu_flash_attention( // max per row tmp_max = at::vec::reduce_all( [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, kvBlockSize); } else { // apply scaling factor and max per row in fusion _mul_reduce_max_fusion_kernel( - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, scaling_factor, kvBlockSize, - qk_data + row * kvBlockSize, + qk_data + row * rkvBlockSize, tmp_max); } tmp_max = qk_max_data[row] > tmp_max ? qk_max_data[row] : tmp_max; if (tmp_max == -std::numeric_limits::infinity()) { // to avoid `nan = exp2f(-inf - (-inf))` - fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize, + fill_stub(conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize, static_cast(0), kvBlockSize); } else { tmp_sum = tmp_max; // qk <- exp(qk - max) and sum per row _exp_reduce_sum_fusion_kernel( - qk_data + row * kvBlockSize, kvBlockSize, - conditional_data_ptr(qk_data, qk_reduced_data) + row * kvBlockSize, + qk_data + row * rkvBlockSize, kvBlockSize, + conditional_data_ptr(qk_data, qk_reduced_data) + row * ekvBlockSize, tmp_sum); // exp_tmp <- exp(max[row] - max) exp_tmp = std::exp(qk_max_data[row] - tmp_max); @@ -431,12 +716,40 @@ void cpu_flash_attention( if (n > 0) { vec::map( [exp_tmp](Vec x) { return x * Vec(exp_tmp); }, - dst_data + row * headSize, dst_data + row * headSize, headSize); + dst_data + row * rHeadSize, + dst_data + row * rHeadSize, + headSize); } } + if (need_pack && kvBlockSize % 2 != 0) { + // Pad: [qSplitSize,kvSplitSize] -> [qSplitSize,kvSplitSize + 1] + *(qk_reduced_data + row * (1 + kvBlockSize) + kvBlockSize) = scalar_t(0); + } } // Calculate Softmax(q @ k.T) @ v - cpublas::gemm( + if (need_pack) { + int64_t psize = n / kvSplitSize * ekvSplitSize; + if constexpr (std::is_same_v) { + for (int64_t b = 0; b < headSize; b += packb_size) { + cpublas::brgemm( + qBlockSize, + packb_size, + ekvBlockSize, + ekvBlockSize, + packb_size, + rHeadSize, + 1.0, + n == 0 ? 0.f : 1.f, + qk_reduced_data, + value_reorder_ptr + + i * num_head * kv_padding_size * rHeadSize + + j * kv_padding_size * rHeadSize + psize * rHeadSize + + b * ekvBlockSize, + dst_data + b); + } + } + } else { + cpublas::gemm( TransposeType::NoTranspose, TransposeType::NoTranspose, headSize, @@ -451,6 +764,7 @@ void cpu_flash_attention( n == 0 ? static_cast(0) : static_cast(1), dst_data, headSize); + } } // dst <- dst / sum[row] @@ -465,7 +779,7 @@ void cpu_flash_attention( vec::map( [sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); }, out_data + i * oStrideB + j * oStrideH + m * oStrideM + row * oStrideM, - dst_data + row * headSize, + dst_data + row * rHeadSize, headSize); } // Store logsumexp for backward @@ -478,7 +792,9 @@ void cpu_flash_attention( data_index_step(i, batchSize, j, num_head, k, qSlice); } }); - + if (need_pack) { + cpublas::brgemm_release(); + } } template @@ -615,8 +931,7 @@ void cpu_flash_attention_backward( at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype)); accum_t* dsum_data = dsum.data_ptr(); - for (const auto z : c10::irange(begin, end)) { - (void)z; // Suppress unused variable + for (C10_UNUSED auto z : c10::irange(begin, end)) { // rowsum of grad_out * out for (int64_t m = 0; m < qSize; m += qSplitSize) { int64_t qBlockSize = std::min(qSplitSize, qSize - m); @@ -826,6 +1141,13 @@ void cpu_flash_attention_backward( AT_PRIVATE_CASE_TYPE_USING_HINT( \ at::ScalarType::Half, mask_t, __VA_ARGS__)) +#define FLASH_ATTENTION_KERNEL(FNAME, PACK, TYPE1, TYPE2, SEQ1, SEQ2, ...) \ + if (PACK) { \ + FNAME(__VA_ARGS__); \ + } else { \ + FNAME(__VA_ARGS__); \ + } + void flash_attention_kernel_impl( const Tensor& output, const Tensor& logsumexp, @@ -838,33 +1160,37 @@ void flash_attention_kernel_impl( std::optional scale) { auto q_seq_len = query.size(2); + // When q_seq_len and k_seq_len are long enough, + // cpu_flash_attention with pack has better performance. + bool could_pack = (query.scalar_type() == kHalf && cpublas::need_pack(kHalf)); + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, query.scalar_type(), "flash_attention", [&] { if (!attn_mask.has_value()) { if (q_seq_len >= 768) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 256, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else if (q_seq_len >= 192) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 64, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, scalar_t, 32, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } } else { AT_DISPATCH_MASK_TYPES(attn_mask.value().scalar_type(), "flash_attention_mask", [&]() { if (q_seq_len >= 768) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 256, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else if (q_seq_len >= 192) { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 64, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } else { - cpu_flash_attention( + FLASH_ATTENTION_KERNEL(cpu_flash_attention, could_pack, scalar_t, mask_t, 32, 512, output, logsumexp, query, key, value, dropout_p, is_causal, attn_mask, scale); } @@ -873,6 +1199,8 @@ void flash_attention_kernel_impl( }); } +#undef FLASH_ATTENTION_KERNEL + void flash_attention_backward_kernel_impl( const at::Tensor& grad_q, const at::Tensor& grad_k, diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp index 7eaac38c21c8ad..d8b4259775d968 100644 --- a/aten/src/ATen/native/cpu/LerpKernel.cpp +++ b/aten/src/ATen/native/cpu/LerpKernel.cpp @@ -19,7 +19,7 @@ Vectorized is_lerp_weight_small(Vectorized weight) { // is_lerp_weight_small doesn't work for complex because z.abs() returns a // complex vector which can't be compared. Either implement it with z.abs_2_(), // or fallback to the scalar function. -#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER)) +#if !(defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) || defined(CPU_CAPABILITY_SVE)) template Vectorized> is_lerp_weight_small(Vectorized> weight) { using vec_reg_t = decltype(weight.abs_2_()); diff --git a/aten/src/ATen/native/cpu/PaddingKernel.cpp b/aten/src/ATen/native/cpu/PaddingKernel.cpp index 302346c4515c9b..1aabb8a3d50d26 100644 --- a/aten/src/ATen/native/cpu/PaddingKernel.cpp +++ b/aten/src/ATen/native/cpu/PaddingKernel.cpp @@ -486,7 +486,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int cpu_padding(output, input, param); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad1d", [&] { cpu_padding(output, input, param); }); @@ -496,7 +496,7 @@ void reflection_pad1d_kernel_impl(const Tensor& output, const Tensor& input, Int void reflection_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -513,14 +513,14 @@ void reflection_pad2d_kernel_impl(const Tensor& output, const Tensor& input, Int } else { switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "reflection_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -537,14 +537,14 @@ void reflection_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "reflection_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -603,7 +603,7 @@ void reflection_pad3d_backward_kernel_impl( // replication padding void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, IntArrayRef padding) { PaddingParams param{input, output, padding}; - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,input.scalar_type(), "replication_pad1d", [&] { cpu_padding(output, input, param); }); @@ -612,7 +612,7 @@ void replication_pad1d_kernel_impl(const Tensor& output, const Tensor& input, In void replication_pad1d_backward_kernel_impl( const Tensor& grad_input, const Tensor& grad_output, IntArrayRef padding) { PaddingParams param{grad_input, grad_output, padding}; - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad1d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); @@ -622,14 +622,14 @@ void replication_pad2d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (input.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad2d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -645,14 +645,14 @@ void replication_pad2d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (grad_output.suggest_memory_format()) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad2d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); @@ -667,14 +667,14 @@ void replication_pad3d_kernel_impl(const Tensor& output, const Tensor& input, In PaddingParams param{input, output, padding}; switch (padding_memory_format_3d(input)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d", [&] { cpu_padding(output, input, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, input.scalar_type(), + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, input.scalar_type(), "replication_pad3d_channels_last", [&]{ cpu_padding_channels_last(output, input, param); }); @@ -690,14 +690,14 @@ void replication_pad3d_backward_kernel_impl( PaddingParams param{grad_input, grad_output, padding}; switch (padding_memory_format_3d(grad_output)) { case at::MemoryFormat::Contiguous: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward", [&] { cpu_padding_backward(grad_input, grad_output, param); }); break; } case at::MemoryFormat::ChannelsLast3d: { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, grad_output.scalar_type(), + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, grad_output.scalar_type(), "replication_pad3d_backward_channels_last", [&]{ cpu_padding_backward_channels_last(grad_input, grad_output, param); }); diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index 95119b5ac08580..6af22033c805e0 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -867,8 +867,8 @@ void scatter_reduce_expanded_index_kernel( } void gather_expanded_index_kernel(const Tensor& result, const Tensor& self, const Tensor& index) { - AT_DISPATCH_FLOATING_TYPES_AND( - ScalarType::BFloat16, self.scalar_type(), "gather_expanded_index", [&] { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "gather_expanded_index", [&] { cpu_gather_expanded_index_kernel(result, index, self); }); } diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 44722034ab3baa..b374935036dad9 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -325,7 +325,7 @@ static void isin_default_kernel_cpu( .check_all_same_dtype(false) .build(); // Dispatch based on promoted type. - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(1), "isin_default_cpu", [&]() { cpu_kernel(iter, [&](scalar_t element_val) -> bool { const auto* test_element_data = test_elements_flat.const_data_ptr(); for (const auto j : c10::irange(test_elements_flat.numel())) { diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index e74621bef67506..a558c1bf13139a 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -159,6 +159,12 @@ inline void transpose(int64_t M, int64_t N, const float* src, int64_t ld_ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); } + +template <> +inline void transpose(int64_t M, int64_t N, const uint16_t* src, int64_t ld_src, uint16_t* dst, int64_t ld_dst) { + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); + fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst); +} #endif template diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 991c7d2dba16db..741d05bdd7169e 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -964,9 +964,9 @@ ScalingType get_scaling_type( } // namespace -// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax +// Computes matrix multiply + bias while applying scaling to input and output matrices // Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. -// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. +// If output matrix type is 16 or 32-bit type, scale_result is not applied. // Known limitations: // - Only works if mat1 is row-major and mat2 is column-major // - Only works if matrices sizes are divisible by 32 @@ -1068,9 +1068,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, const auto out_dtype_ = args.result->scalar_type(); TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); - // Some scaled_gemms require an amax to populate lets create one here - Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float)); - #ifdef USE_ROCM auto tuning_ctx = at::cuda::tunable::getTuningContext(); if (tuning_ctx->IsTunableOpEnabled()) { @@ -1126,7 +1123,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr; params.ldc = args.result_ld; params.c_dtype = out_dtype_; - params.amax_ptr = amax.data_ptr(); params.use_fast_accum = use_fast_accum; if (transa_ && transb_) { TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) @@ -1150,11 +1146,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else #endif { -#if defined(USE_ROCM) && ROCM_VERSION >= 60200 - // hipBlasLT requires scaleD to be set to something in order to use AMAX - auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA); - auto dummy_scale = at::ones(1, dummy_options); -#endif at::cuda::blas::scaled_gemm( args.transa, args.transb, @@ -1172,14 +1163,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, bias ? bias->data_ptr(): nullptr, bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, args.result->data_ptr(), -#if defined(USE_ROCM) && ROCM_VERSION >= 60200 - scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(), -#else scale_result ? scale_result->data_ptr() : nullptr, -#endif args.result_ld, out_dtype_, - amax.data_ptr(), use_fast_accum); } diff --git a/aten/src/ATen/native/cuda/Col2Im.cu b/aten/src/ATen/native/cuda/Col2Im.cu index bb6d4748deb1f5..4d8f4935039165 100644 --- a/aten/src/ATen/native/cuda/Col2Im.cu +++ b/aten/src/ATen/native/cuda/Col2Im.cu @@ -91,7 +91,7 @@ void col2im_out_cuda_template( if (input.dim() == 2) { // Force batch batched_input = false; - input = input.view({1, input.size(0), input.size(1)}); + input = input.unsqueeze(0); } int64_t batch_size = input.size(0); @@ -102,7 +102,7 @@ void col2im_out_cuda_template( output.resize_({batch_size, n_output_plane, output_height, output_width}); int64_t output_batch_stride = output.stride(0); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "col2im_out_cuda", [&] { int64_t height_col = (output_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / @@ -134,10 +134,10 @@ void col2im_out_cuda_template( output.mutable_data_ptr(), output_batch_stride); - if (!batched_input) { - output.resize_({n_output_plane, output_height, output_width}); - } }); + if (!batched_input) { + output = output.squeeze(0); + } } } // namespace diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index d7a118e6a9584c..1a969cfdbdcc43 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -237,7 +237,7 @@ void floating_half_bfloat16_(TensorList tensors) { OP_CUSTOM_FUNCTOR(function, op_name, functor_name); OP(floating_half_bfloat16, erfc, Erfc); -OP(floating_half, lgamma, Lgamma); +OP(floating_half_bfloat16, lgamma, Lgamma); OP(floating_half_bfloat16, trunc, Truncf); OP(floating_half_bfloat16, floor, Floor); OP(floating_half_bfloat16, ceil, Ceil); @@ -304,7 +304,7 @@ struct Sign { } }; -OP_CUSTOM_FUNCTOR(floating_half_bfloat16, sigmoid, Sigmoid) +OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, sigmoid, Sigmoid) OP_CUSTOM_FUNCTOR(floating_half_bfloat16, round, Round) OP_CUSTOM_FUNCTOR(floating_half_bfloat16, frac, Trunc) OP_CUSTOM_FUNCTOR(floating_complex_half_bfloat16, reciprocal, Reciprocal) diff --git a/aten/src/ATen/native/cuda/Im2Col.cu b/aten/src/ATen/native/cuda/Im2Col.cu index 312ad893c0d81e..d74a5f5f641a06 100644 --- a/aten/src/ATen/native/cuda/Im2Col.cu +++ b/aten/src/ATen/native/cuda/Im2Col.cu @@ -81,7 +81,7 @@ static void im2col_out_cuda_template( if (input.dim() == 3) { batched_input = false; - input = input.view({1, input.size(0), input.size(1), input.size(2)}); + input = input.unsqueeze(0); } int64_t batch_size = input.size(0); @@ -103,7 +103,7 @@ static void im2col_out_cuda_template( output.resize_({batch_size, n_output_plane, output_length}); // Launch kernel - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(kHalf, kBFloat16, kBool, input.scalar_type(), "im2col_out_cuda", [&] { Tensor input_n; Tensor output_n; @@ -131,10 +131,10 @@ static void im2col_out_cuda_template( output_n.mutable_data_ptr()); } - if (!batched_input) { - output.resize_({n_output_plane, output_length}); - } }); + if (!batched_input) { + output = output.squeeze(0); + } } } // namespace diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 613ecf9b9d3f98..ee83ee5c6d3b8a 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -1350,7 +1350,7 @@ template void index_select_out_cuda_impl( Tensor& out, const Tensor& self, - long dim, + uint64_t dim, const Tensor& index) { uint64_t numIndices = index.numel(); uint64_t selfDims = self.dim() == 0 ? 1 : self.dim(); @@ -1511,13 +1511,13 @@ Tensor& index_select_out_cuda( self.qscheme() == kPerTensorAffine, "Only per_tensor quantized quantized tensors are supported by index_select.") AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_cuda", [&] { - index_select_out_cuda_impl(out, self, dim, index); + index_select_out_cuda_impl(out, self, (uint64_t) dim, index); }); } else { AT_DISPATCH_V2( out.scalar_type(), "index_select_cuda", - AT_WRAP([&] { index_select_out_cuda_impl(out, self, dim, index); }), + AT_WRAP([&] { index_select_out_cuda_impl(out, self, (uint64_t) dim, index); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kComplexHalf, kHalf, diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index e87f46cd844eab..e5fb9230de7637 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +#include #include #include #include @@ -70,7 +71,16 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ auto temp_storage = allocator.allocate(temp_storage_bytes); cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); int num_nonzeros_h; - at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); + auto pinned_num_nonzeros_h = at::detail::empty_cpu( + {1}, /* size */ + c10::CppTypeToScalarType(), /* dtype */ + std::nullopt, /* layout */ + std::nullopt, /* device */ + true, /* pin_memory */ + std::nullopt /* memory format */ + ); + at::cuda::memcpy_and_sync((void *)pinned_num_nonzeros_h.const_data_ptr(), num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); + num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr()); //expected output size is num_nonzeros x ndim //we are producing output with size {num_nonzeros, ndim} and strides {1, num_nonzeros} (that is, transposed ndim x num_nonzeros output) //we are able to directly use passed output with this size and strides, and we can also (per contract) diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 79b23ea8afd4fc..7a25975b624b04 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1092,7 +1092,11 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ } constexpr int min_values_per_thread = 16; +#ifndef USE_ROCM constexpr int max_values_per_thread = 256; +#else + constexpr int max_values_per_thread = 1024; +#endif if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) { // Divide the input across warps in a thread-block, if that leaves at least diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 6c3e7075bc6a8f..f76d6bfb66a727 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -69,6 +69,8 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( namespace { +constexpr int kNumSMsForH100 = 132; + using DtypeScale = float; using DtypeAccum = float; using DtypeEpilogue = float; @@ -115,11 +117,20 @@ struct Schedule { using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; }; +int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +int round_up_to_nearest_multiple(int a, int b) { + return ceildiv(a, b) * b; +} + // Cutlass rowwise kernel template < typename TileShape, typename ClusterShape, typename PingPong, + typename Transposed, typename FastAccum, typename DtypeA, typename DtypeB, @@ -150,7 +161,10 @@ void f8f8bf16_rowwise_impl( using LayoutInputB = cutlass::layout::ColumnMajor; constexpr int AlignmentInputB = 16 / sizeof(DtypeB); - using LayoutOutput = cutlass::layout::RowMajor; + using LayoutOutput = std::conditional_t< + Transposed::value, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>; constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput); // Tag indicating the minimum SM that supports the intended feature @@ -167,8 +181,12 @@ void f8f8bf16_rowwise_impl( using WScale = cutlass::epilogue::fusion:: Sm90RowBroadcast; - using Bias = cutlass::epilogue::fusion:: - Sm90RowBroadcast; + using Bias = std::conditional_t< + Transposed::value, + cutlass::epilogue::fusion:: + Sm90ColBroadcast, + cutlass::epilogue::fusion:: + Sm90RowBroadcast>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; @@ -284,27 +302,7 @@ void f8f8bf16_rowwise_impl( C10_CUDA_KERNEL_LAUNCH_CHECK(); } -// FP8 Rowwise Cutlass kernel dispatch. -enum class KernelMode { Small, Large, Default }; - -KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { - auto M = XQ.size(0); - auto K = XQ.size(1); - auto N = WQ.size(0); - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || - (K >= 2048 && N >= 2048)); - if (M <= 128 || N <= 128) { - return KernelMode::Small; - } else if (use_large_kernel) { - return KernelMode::Large; - } else { - return KernelMode::Default; - } -} - -template +template void dispatch_fp8_rowwise_kernel_on_tile_size( at::Tensor XQ, at::Tensor WQ, @@ -312,24 +310,143 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( at::Tensor w_scale, std::optional bias, at::Tensor out) { - KernelMode kernel = get_kernel_mode(XQ, WQ); - if (kernel == KernelMode::Small) { + int M = XQ.size(0); + int N = WQ.size(1); + + // We prefer to use smaller tiles (less wasted compute in case of padding), + // but if this causes us to have more CUDA blocks than there are SMs on the + // GPU then we'll hit wave quantization, hence we'll switch to larger tiles. + if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) * + ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <= + kNumSMsForH100 / cute::size(ClusterShape{})) { return f8f8bf16_rowwise_impl< /*TileShape=*/cute::Shape, - /*ClusterShape=*/cute::Shape, + ClusterShape, /*PingPong=*/std::false_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (kernel == KernelMode::Large) { + } else { return f8f8bf16_rowwise_impl< /*TileShape=*/cute::Shape, - /*ClusterShape=*/cute::Shape, + ClusterShape, /*PingPong=*/std::true_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } +} + +template < + typename ClusterShape, + typename Transposed, + typename FastAccum, + typename DtypeA, + typename DtypeB, + typename DtypeBias> +void handle_transposition( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + at::Tensor out) { + if constexpr (!Transposed::value) { + dispatch_fp8_rowwise_kernel_on_tile_size< + ClusterShape, + Transposed, + FastAccum, + DtypeA, + DtypeB, + DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out); } else { - return f8f8bf16_rowwise_impl< - /*TileShape=*/cute::Shape, + dispatch_fp8_rowwise_kernel_on_tile_size< + ClusterShape, + Transposed, + FastAccum, + DtypeB, + DtypeA, + DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t()); + } +} + +template +void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + std::optional bias, + at::Tensor out) { + int M = XQ.size(0); + int N = WQ.size(1); + + // All the tiles we use have sizes which are multiples of 64, hence any + // non-multiple of 64 will get padded anyways. Let's round up to simplify. + M = round_up_to_nearest_multiple(M, 64); + N = round_up_to_nearest_multiple(N, 64); + + // Small/skinny shapes with odd multiples of 64. + if (M == 64 && N >= 3072) { + return handle_transposition< /*ClusterShape=*/cute::Shape, - /*PingPong=*/std::false_type, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (N == 64 && M >= 3072) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (M == 192 && N >= 4096) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (N == 192 && M >= 4096) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + + // Now to odd multiples of 128 (but only if not too large). + if (M * N <= 4096 * 4096) { + if (M % 256 > 0 && N % 256 == 0) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + if (N % 256 > 0 && M % 256 == 0) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + if (M % 256 > 0 && N % 256 > 0) { + if ((M <= N) ^ (M * N <= 1024 * 1024)) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } + } + + // General case for large tensors. + if ((M <= N) ^ (M >= 2048 && N >= 2048)) { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); + } else { + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, Types...>(XQ, WQ, x_scale, w_scale, bias, out); } } @@ -344,11 +461,13 @@ void dispatch_fp8_rowwise_kernel_on_fast_accum( bool use_fast_accum, at::Tensor out) { if (use_fast_accum) { - dispatch_fp8_rowwise_kernel_on_tile_size( - XQ, WQ, x_scale, w_scale, bias, out); + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); } else { - dispatch_fp8_rowwise_kernel_on_tile_size( - XQ, WQ, x_scale, w_scale, bias, out); + dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< + std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out); } } diff --git a/aten/src/ATen/native/cuda/Sorting.cpp b/aten/src/ATen/native/cuda/Sorting.cpp index 9381c0e4f02a4c..310db22aa0571f 100644 --- a/aten/src/ATen/native/cuda/Sorting.cpp +++ b/aten/src/ATen/native/cuda/Sorting.cpp @@ -21,6 +21,9 @@ #include #include #include +#include +#include +#include #endif namespace at::native { @@ -146,8 +149,8 @@ Tensor median_impl(const Tensor& self, bool ignore_nan) { return at::where(sorted[-1].isnan(), sorted[-1], sorted[k]); } else { // For torch.nanmedian return the middle element among the non-nan values - int64_t k = ((size - 1) - sorted.isnan().sum().item()) / 2; - return sorted[k].clone(); // Clone so we aren't keeping `sorted` alive + Tensor k = at::div(at::rsub(sorted.isnan().sum(), (size - 1)), 2).to(kLong); + return at::index(sorted, {k}); } } diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 6272bbb9b75df5..290be3926c6ffc 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -177,12 +177,11 @@ struct KthValueLauncher { cuda::detail::TensorInfo values_info, int collapse_values_dim, cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, + C10_UNUSED int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { - (void)collapse_indices_dim; // Suppress unused variable warning dim3 grid; if (!getGridFromTiles(num_slices, grid)) { AT_ERROR("slices are too many"); @@ -213,15 +212,13 @@ struct MedianLauncher { template inline void launch( cuda::detail::TensorInfo values_info, - int collapse_values_dim, + C10_UNUSED int collapse_values_dim, cuda::detail::TensorInfo indices_info, - int collapse_indices_dim, + C10_UNUSED int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { - (void)collapse_values_dim; // Suppress unused variable warning - (void)collapse_indices_dim; // Suppress unused variable warning dim3 grid; if (!getGridFromTiles(num_slices, grid)) { AT_ERROR("slices are too many"); diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index f6956405174261..1a3ee09ac931c3 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -101,33 +101,41 @@ REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); +struct Msg { + static constexpr size_t MAX_MSG_LENGTH = 256; + char msg[MAX_MSG_LENGTH]; +}; template -__global__ void _assert_async_cuda_kernel(const scalar_t* input) { - CUDA_KERNEL_ASSERT(input[0] != 0); +__global__ void _assert_async_cuda_kernel(const scalar_t* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != 0, msg.msg); } -__global__ void _assert_async_cuda_kernel(const c10::complex* input) { - CUDA_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +__global__ void _assert_async_cuda_kernel(const c10::complex* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != c10::complex(0, 0), msg.msg); } -__global__ void _assert_async_cuda_kernel(const c10::complex* input) { - CUDA_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +__global__ void _assert_async_cuda_kernel(const c10::complex* input, Msg msg) { + CUDA_KERNEL_ASSERT_MSG(input[0] != c10::complex(0, 0), msg.msg); } -void _assert_async_cuda(const Tensor& self_tensor) { +void _assert_async_msg_cuda(const Tensor& self_tensor, c10::string_view assert_msg) { const TensorBase &self = get_tensor_base(self_tensor); auto n = self.numel(); TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous"); TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); auto stream = at::cuda::getCurrentCUDAStream(); + Msg msg; + size_t copy_length = assert_msg.length(); + TORCH_CHECK(copy_length < Msg::MAX_MSG_LENGTH - 1, "Message length must be smaller than " + std::to_string(Msg::MAX_MSG_LENGTH - 1)); + std::copy_n(assert_msg.data(), copy_length, msg.msg); + msg.msg[copy_length] = '\0'; // Ensure null-termination AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_cuda", [&] { - _assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); + _assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr(), msg); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } -// TODO (tmanlaibaatar) Ignore assert msg for now -void _assert_async_msg_cuda(const Tensor& self_tensor, c10::string_view assert_msg) { - _assert_async_cuda(self_tensor); +void _assert_async_cuda(const Tensor& self_tensor) { + _assert_async_msg_cuda(self_tensor, ""); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/Unique.cu b/aten/src/ATen/native/cuda/Unique.cu index 39e80e0a68c3c8..67eeacee29a850 100644 --- a/aten/src/ATen/native/cuda/Unique.cu +++ b/aten/src/ATen/native/cuda/Unique.cu @@ -191,7 +191,7 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) { // lack of hashtable implementation in thrust auto [output, inverse, _] = internal::unique_cuda_template(self, false, return_inverse, false); return std::make_tuple(output, inverse); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple @@ -200,21 +200,21 @@ _unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust return internal::unique_cuda_template(self, false, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) { return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] { return unique_dim_cuda_template(self, dim, false, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple unique_dim_consecutive_cuda(const Tensor& self, const int64_t dim, const bool return_inverse, const bool return_counts) { return AT_DISPATCH_V2(self.scalar_type(), "unique_dim", AT_WRAP([&] { return unique_dim_cuda_template(self, dim, true, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } std::tuple @@ -224,7 +224,7 @@ unique_consecutive_cuda(const Tensor& self, const bool return_inverse, const boo // The current CUDA implementation of unique always sort due to the // lack of hashtable implementation in thrust return internal::unique_cuda_template(self, true, return_inverse, return_counts); - }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES), kBool, kBFloat16, kHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } return unique_dim_consecutive_cuda(self, dim.value(), return_inverse, return_counts); } diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu index bbd8673bcf5a65..1bda65815d6d49 100644 --- a/aten/src/ATen/native/cuda/UniqueCub.cu +++ b/aten/src/ATen/native/cuda/UniqueCub.cu @@ -339,6 +339,7 @@ INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint32_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint64_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(uint16_t); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(bool); +INSTANTIATE_UNIQUE_CUDA_TEMPLATE(BFloat16); INSTANTIATE_UNIQUE_CUDA_TEMPLATE(at::Half); #undef INSTANTIATE diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index d00a8eb30698b8..c70a96f937cb81 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -22,6 +22,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -43,6 +44,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -86,9 +88,9 @@ using graph_and_tensors = std::tuple< std::shared_ptr, // Q, std::shared_ptr, // K, std::shared_ptr, // V, + std::optional>, // Bias std::shared_ptr, // Attn_scale, // TODO(eqy): additional options - // std::shared_ptr, // Bias, // std::shared_ptr, // SEQ_LEN_Q, // std::shared_ptr, // SEQ_LEN_KV, std::shared_ptr, // Seed, @@ -104,7 +106,8 @@ using graph_and_tensors_backward = std::tuple< std::shared_ptr, // Q, std::shared_ptr, // K, std::shared_ptr, // V, - std::shared_ptr, // Attn_scale + std::optional>, // Bias, + std::shared_ptr, // Attn_scale, std::shared_ptr, // Seed, std::shared_ptr, // Offset, std::shared_ptr, // O, @@ -126,6 +129,8 @@ struct MHAParams { std::array q_stride; std::array k_stride; std::array v_stride; + std::array bias_dim; + std::array bias_stride; int64_t b; int64_t h; int64_t s_q; @@ -135,6 +140,9 @@ struct MHAParams { double dropout_probability; bool is_causal; bool return_softmaxstats; + // might be redundant if we take 0 dim/stride + // as signaling no-bias + bool has_attn_bias; }; void setMHAParams( @@ -148,6 +156,7 @@ void setMHAParams( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, double dropout_probability, bool is_causal, bool return_softmaxstats) { @@ -166,6 +175,7 @@ void setMHAParams( params.dropout_probability = dropout_probability; params.is_causal = is_causal; params.return_softmaxstats = return_softmaxstats; + params.has_attn_bias = attn_bias.has_value(); TORCH_INTERNAL_ASSERT( q.sizes().size() == MAX_MHA_DIM, "Q tensor has unexpected number of dims, please report a bug to PyTorch."); @@ -190,6 +200,17 @@ void setMHAParams( std::copy(k.strides().begin(), k.strides().end(), params.k_stride.begin()); std::copy(v.sizes().begin(), v.sizes().end(), params.v_dim.begin()); std::copy(v.strides().begin(), v.strides().end(), params.v_stride.begin()); + // uninit is OK as the struct is memset 0'd + if (params.has_attn_bias) { + std::copy( + attn_bias.value().sizes().begin(), + attn_bias.value().sizes().end(), + params.bias_dim.begin()); + std::copy( + attn_bias.value().strides().begin(), + attn_bias.value().strides().end(), + params.bias_stride.begin()); + } } struct MHACacheKeyWrapper : ParamsWrapper { @@ -203,6 +224,7 @@ struct MHACacheKeyWrapper : ParamsWrapper { const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, double dropout_probability, bool is_causal, bool return_softmaxstats) { @@ -217,6 +239,7 @@ struct MHACacheKeyWrapper : ParamsWrapper { q, k, v, + attn_bias, dropout_probability, is_causal, return_softmaxstats); @@ -285,6 +308,7 @@ auto build_graph_and_tensors( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -301,36 +325,6 @@ auto build_graph_and_tensors( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - auto Q = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(std::vector( - q.sizes().data(), q.sizes().data() + q.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - q.sizes(), - std::vector( - q.strides().data(), - q.strides().data() + q.strides().size())))); - auto K = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(std::vector( - k.sizes().data(), k.sizes().data() + k.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - k.sizes(), - std::vector( - k.strides().data(), - k.strides().data() + k.strides().size())))); - auto V = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(std::vector( - v.sizes().data(), v.sizes().data() + v.sizes().size())) - .set_stride(fixSizeOneDimStrideSDPA( - v.sizes(), - std::vector( - v.strides().data(), - v.strides().data() + v.strides().size())))); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Attn_scale") @@ -338,11 +332,6 @@ auto build_graph_and_tensors( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - // TODO(eqy): support bias in the future in a follow-up PR - // auto bias = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("bias") - // .set_dim({b, 1, s_q, s_kv}) - // .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") .set_dim({1, 1, 1, 1}) @@ -360,11 +349,30 @@ auto build_graph_and_tensors( .set_causal_mask(is_causal) .set_attn_scale(attn_scale) .set_dropout(dropout_probability, seed, offset); - // Optional bias in flash attention is only supported 8.9.3 onwards - if (cudnnGetVersion() >= 8904) { - // scaled_dot_product_flash_attention_options.set_alibi_mask(true); + auto Q = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); + auto K = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); + auto V = mha_graph->tensor( + fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); + std::optional> bias; + if (attn_bias.has_value()) { + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seq_q") .set_dim({b, 1, 1, 1}) @@ -376,20 +384,9 @@ auto build_graph_and_tensors( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); - // if (cudnnGetVersion() >= 8903) { - // scaled_dot_product_flash_attention_options.set_bias(bias) - // .set_padding_mask(true) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); - O->set_output(true) - .set_dim(std::vector( - o.sizes().data(), o.sizes().data() + o.sizes().size())) - .set_stride(std::vector( - o.strides().data(), o.strides().data() + o.strides().size())); + O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); @@ -407,6 +404,7 @@ auto build_graph_and_tensors( std::move(Q), std::move(K), std::move(V), + std::move(bias), std::move(attn_scale), std::move(seed), std::move(offset), @@ -427,6 +425,7 @@ auto build_graph_and_tensors_backward( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -447,24 +446,6 @@ auto build_graph_and_tensors_backward( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - auto Q = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim(std::vector(q.sizes().begin(), q.sizes().end())) - .set_stride( - std::vector(q.strides().begin(), q.strides().end()))); - auto K = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("K") - .set_dim(std::vector(k.sizes().begin(), k.sizes().end())) - .set_stride( - std::vector(k.strides().begin(), k.strides().end()))); - auto V = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("V") - .set_dim(std::vector(v.sizes().begin(), v.sizes().end())) - .set_stride( - std::vector(v.strides().begin(), v.strides().end()))); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Attn_scale") @@ -472,6 +453,31 @@ auto build_graph_and_tensors_backward( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); + auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() + .set_name("CUDNN_SDPA_BACKWARD") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); + std::optional> bias; + if (attn_bias.has_value()) { + bias = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim(attn_bias.value().sizes().vec()) + .set_stride(attn_bias.value().strides().vec())); + sdpa_backward_options.set_bias(bias.value()); + } auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") .set_dim({1, 1, 1, 1}) @@ -482,47 +488,27 @@ auto build_graph_and_tensors_backward( .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); - auto O = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("O") - .set_dim(std::vector(o.sizes().begin(), o.sizes().end())) - .set_stride( - std::vector(o.strides().begin(), o.strides().end()))); - auto STATS = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("Stats") - .set_dim(std::vector( - softmaxstats.sizes().begin(), softmaxstats.sizes().end())) - .set_stride(std::vector( - softmaxstats.strides().begin(), softmaxstats.strides().end())) - .set_data_type(fe::DataType_t::FLOAT)); - auto DO = mha_graph->tensor( - fe::graph::Tensor_attributes() - .set_name("DO") - .set_dim(std::vector(dO.sizes().begin(), dO.sizes().end())) - .set_stride( - std::vector(dO.strides().begin(), dO.strides().end()))); - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("CUDNN_SDPA_BACKWARD") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); + auto O = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(softmaxstats.sizes().vec()) + .set_stride(softmaxstats.strides().vec()) + .set_data_type(fe::DataType_t::FLOAT)); + auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("DO") + .set_dim(dO.sizes().vec()) + .set_stride(dO.strides().vec())); if (dropout_probability != 0.0f) { sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); } auto [DQ, DK, DV] = mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); - DQ->set_output(true) - .set_dim(std::vector(dQ.sizes().begin(), dQ.sizes().end())) - .set_stride( - std::vector(dQ.strides().begin(), dQ.strides().end())); - DK->set_output(true) - .set_dim(std::vector(dK.sizes().begin(), dK.sizes().end())) - .set_stride( - std::vector(dK.strides().begin(), dK.strides().end())); - DV->set_output(true) - .set_dim(std::vector(dV.sizes().begin(), dV.sizes().end())) - .set_stride( - std::vector(dV.strides().begin(), dV.strides().end())); + DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( @@ -534,6 +520,7 @@ auto build_graph_and_tensors_backward( std::move(Q), std::move(K), std::move(V), + std::move(bias), std::move(attn_scale), std::move(Seed), std::move(Offset), @@ -559,6 +546,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -573,6 +561,11 @@ void run_cudnn_SDP_fprop( softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); } + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel()) { + return; + } + auto key = MHACacheKeyWrapper( b, h, @@ -583,6 +576,7 @@ void run_cudnn_SDP_fprop( q, k, v, + attn_bias, dropout_probability, is_causal, return_softmaxstats); @@ -605,13 +599,14 @@ void run_cudnn_SDP_fprop( q, k, v, + attn_bias, softmaxstats, o, dropoutseed, dropoutoffset, handle); } - auto [mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats] = + auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = graph_and_tensors_values; std::unordered_map, void*> variant_pack = { @@ -619,13 +614,15 @@ void run_cudnn_SDP_fprop( {K, k.data_ptr()}, {V, v.data_ptr()}, {attn_scale, &scaling_factor}, - //{bias, bias.data_ptr()}, {seed, dropoutseed.data_ptr()}, {offset, dropoutoffset.data_ptr()}, {O, o.data_ptr()}}; if (return_softmaxstats) { variant_pack[Stats] = softmaxstats.data_ptr(); } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); + } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); @@ -647,6 +644,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, @@ -655,6 +653,12 @@ void run_cudnn_SDP_bprop( Tensor& dV, const Tensor& dropoutseed, const Tensor& dropoutoffset) { + // do nothing if we got 0-element tensors + if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || + !softmaxstats.numel()) { + return; + } + Tensor dO_ = dO; if (!dO.strides()[dO.strides().size() - 1]) { TORCH_WARN( @@ -694,6 +698,7 @@ void run_cudnn_SDP_bprop( q, k, v, + attn_bias, dropout_probability, is_causal, true); @@ -715,6 +720,7 @@ void run_cudnn_SDP_bprop( q, k, v, + attn_bias, o, dO_, softmaxstats, @@ -726,8 +732,20 @@ void run_cudnn_SDP_bprop( handle); } auto - [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] = - graph_and_tensors_backward_values; + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + Seed, + Offset, + O, + Do, + Stats, + Dq, + Dk, + Dv] = graph_and_tensors_backward_values; std::unordered_map, void*> variant_pack = {// inputs {Q, q.data_ptr()}, @@ -746,6 +764,9 @@ void run_cudnn_SDP_bprop( variant_pack[Seed] = dropoutseed.data_ptr(); variant_pack[Offset] = dropoutoffset.data_ptr(); } + if (attn_bias.has_value()) { + variant_pack[bias.value()] = attn_bias.value().data_ptr(); + } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 8b9315a5a3d85c..3ae1a03b2a7741 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -18,6 +18,7 @@ void run_cudnn_SDP_fprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, Tensor& softmaxstats, Tensor& o, Tensor& dropoutseed, @@ -36,6 +37,7 @@ void run_cudnn_SDP_bprop( const Tensor& q, const Tensor& k, const Tensor& v, + const std::optional& attn_bias, const Tensor& o, const Tensor& dO, const Tensor& softmaxstats, diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 22410e8690ce65..627fa71382e209 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -197,8 +197,8 @@ Tensor group_norm( const Tensor kEmpty; auto memory_format = input.suggest_memory_format(); - const auto& X = input.device().is_cpu() || input.device().is_xpu() ? - input.contiguous(memory_format) : input.contiguous(); + const auto& X = input.device().is_cpu() || input.is_privateuseone() ? + input.contiguous(memory_format) : input.contiguous(); const auto& gamma = weight.defined() ? weight.contiguous() : kEmpty; const auto& beta = bias.defined() ? bias.contiguous() : kEmpty; TORCH_CHECK(!gamma.defined() || gamma.sym_numel() == C); diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index b11bcaba38e689..c739547af9c1ab 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -263,18 +264,15 @@ std::tuple math_native_layer_norm( return std::make_tuple(out, mean, rstd); } -Tensor rms_norm( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - auto bias_opt = std::optional(); - const Tensor& bias = *at::borrow_from_optional_tensor(bias_opt); - (void) _check_layer_norm_inputs(input, normalized_shape, weight, bias); + _check_rms_norm_inputs_symint(input, normalized_shape, weight); std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -295,7 +293,12 @@ Tensor rms_norm( eps_val = eps.value(); } - auto result = input.mul(at::rsqrt(at::pow(input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val))); + // upcast is needed for fp16 and bf16 + c10::ScalarType opmath_t = toOpMathType(input.scalar_type()); + Tensor upcasted_input = input.to(opmath_t); + + Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val)); + Tensor result = upcasted_input.mul(rqrst_input).type_as(input); if (weight_opt.has_value()) { result = result.mul(weight_opt.value()); diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index e35ccf8634bccb..ba2b356c0b0459 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -8,6 +8,41 @@ namespace at::native { namespace { +C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const Tensor& weight /* optional */) { + + const int normalized_ndim = normalized_shape.size(); + TORCH_CHECK( + normalized_ndim >= 1, + "Expected normalized_shape to be at least 1-dimensional, i.e., ", + "containing at least one element, but got normalized_shape = ", + normalized_shape); + TORCH_CHECK( + !weight.defined() || weight.sym_sizes().equals(normalized_shape), + "Expected weight to be of same shape as normalized_shape, but got ", + "weight of shape ", + weight.sym_sizes(), + " and normalized_shape = ", + normalized_shape); + + const auto input_ndim = input.dim(); + const auto input_shape = input.sym_sizes(); + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim) + .equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + AT_ERROR(ss.str()); + } +} + C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( const Tensor& input, IntArrayRef normalized_shape, @@ -71,9 +106,9 @@ void layer_norm_cpu_out( int64_t M, int64_t N); -Tensor rms_norm( +Tensor rms_norm_symint( const Tensor& input, - IntArrayRef normalized_shape, + c10::SymIntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps); diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h index e1a9b2617bd3e2..639e4b6b746b7c 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h @@ -2,10 +2,10 @@ @interface MPSCNNNeuronOp : NSObject -+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)); -+ (MPSCNNNeuronReLU*)relu; -+ (MPSCNNNeuronSigmoid*)sigmoid; -+ (MPSCNNNeuronTanH*)tanh; ++ (MPSCNNNeuron*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)); ++ (MPSCNNNeuron*)relu; ++ (MPSCNNNeuron*)sigmoid; ++ (MPSCNNNeuron*)tanh; @end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm index e722f2765c0893..62eb51661a9a62 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm @@ -8,69 +8,67 @@ @implementation MPSCNNNeuronOp -+ (MPSCNNNeuronHardSigmoid*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)hardSigmoid API_AVAILABLE(ios(11.0), macos(10.13)) { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronHardSigmoid* neuron = nil; dispatch_once(&onceToken, ^{ +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor hardSigmoidDescriptor]]; +#else neuron = [[MPSCNNNeuronHardSigmoid alloc] - initWithDevice:[MetalContext sharedInstance].device - a:1.0 / 6.0 - b:0.5]; + initWithDevice:[MetalContext sharedInstance].device + a:1.0 / 6.0 + b:0.5]; +#endif }); return neuron; -#endif } -+ (MPSCNNNeuronReLU*)relu { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else - static MPSCNNNeuronReLU* relu = nil; ++ (MPSCNNNeuron*)relu { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; dispatch_once(&onceToken, ^{ - relu = [[MPSCNNNeuronReLU alloc] - initWithDevice:[MetalContext sharedInstance].device - a:0]; - }); - return relu; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] + initWithDevice:[MetalContext sharedInstance].device + neuronDescriptor:[MPSCNNNeuronOpDescriptor reluDescriptor]]; +#else + neuron = [[MPSCNNNeuronReLU alloc] + initWithDevice:[MetalContext sharedInstance].device + a:0]; #endif + }); + return neuron; } -+ (MPSCNNNeuronSigmoid*)sigmoid { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)sigmoid { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronSigmoid* sigmoid = nil; dispatch_once(&onceToken, ^{ - sigmoid = [[MPSCNNNeuronSigmoid alloc] - initWithDevice:[MetalContext sharedInstance].device]; - }); - return sigmoid; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor sigmoidDescriptor]]; +#else + neuron = [[MPSCNNNeuronSigmoid alloc] + initWithDevice:[MetalContext sharedInstance].device]; #endif + }); + return neuron; } -+ (MPSCNNNeuronTanH*)tanh { -// Remove this once we support iOS 11.3 -#if TARGET_OS_MACCATALYST - return nil; -#else ++ (MPSCNNNeuron*)tanh { + static MPSCNNNeuron* neuron = nil; static dispatch_once_t onceToken; - static MPSCNNNeuronTanH* tanh = nil; dispatch_once(&onceToken, ^{ - tanh = [[MPSCNNNeuronTanH alloc] - initWithDevice:[MetalContext sharedInstance].device - a:1 - b:1]; - }); - return tanh; +#if TARGET_OS_MACCATALYST + neuron = [[MPSCNNNeuron alloc] initWithDevice:[MetalContext sharedInstance].device neuronDescriptor:[MPSCNNNeuronOpDescriptor tanhDescriptor]]; +#else + neuron = [[MPSCNNNeuronTanH alloc] + initWithDevice:[MetalContext sharedInstance].device + a:1 + b:1]; #endif + }); + return neuron; } @end @@ -85,9 +83,9 @@ + (MPSNNNeuronDescriptor*)hardSigmoidDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = [MPSNNNeuronDescriptor - cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid - a:1.0 / 6.0 - b:0.5]; + cnnNeuronDescriptorWithType:MPSCNNNeuronTypeHardSigmoid + a:1.0 / 6.0 + b:0.5]; }); return neuronDesc; } @@ -97,8 +95,8 @@ + (MPSNNNeuronDescriptor*)reluDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = - [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU - a:0]; + [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeReLU + a:0]; }); return neuronDesc; } @@ -108,7 +106,7 @@ + (MPSNNNeuronDescriptor*)sigmoidDescriptor { static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ neuronDesc = [MPSNNNeuronDescriptor - cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid]; + cnnNeuronDescriptorWithType:MPSCNNNeuronTypeSigmoid]; }); return neuronDesc; } @@ -117,10 +115,9 @@ + (MPSNNNeuronDescriptor*)tanhDescriptor { static dispatch_once_t onceToken; static MPSNNNeuronDescriptor* neuronDesc = nil; dispatch_once(&onceToken, ^{ - neuronDesc = - [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH - a:1.0 - b:1.0]; + neuronDesc = [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeTanH + a:1.0 + b:1.0]; }); return neuronDesc; } diff --git a/aten/src/ATen/native/mkl/LinearAlgebra.cpp b/aten/src/ATen/native/mkl/LinearAlgebra.cpp index 582de5f8213a03..b64ccfeb03feae 100644 --- a/aten/src/ATen/native/mkl/LinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/LinearAlgebra.cpp @@ -154,4 +154,4 @@ void mkl_gemm_f16f16f32( }} // namespace at::native #endif -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index e26cfbf6d8ebaa..8ae620ed0028c1 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -165,6 +165,7 @@ REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_co REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_SVE256_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) // _out variants can be shared between PocketFFT and MKL Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index a63d9ebfa2c150..2f3c791914e3db 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -89,10 +89,22 @@ const std::map& fusion_binary_alg_map(); inline bool mkldnn_bf16_device_check_arm() { return cpuinfo_initialize() && cpuinfo_has_arm_bf16(); } + +inline bool is_arm_neoverse() { + return (cpuinfo_initialize() && cpuinfo_get_uarchs_count() == 1 && + (cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v1 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_v2 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_n1 || + cpuinfo_get_uarch(0)->uarch == cpuinfo_uarch_neoverse_n2)); +} #else constexpr bool mkldnn_bf16_device_check_arm() { return false; } + +constexpr bool is_arm_neoverse() { + return false; +} #endif #if AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index d588e03afbb4f7..db469733fcf3b9 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -227,6 +227,10 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return "uchar"; case ScalarType::Bool: return "bool"; + case ScalarType::ComplexHalf: + return "half2"; + case ScalarType::ComplexFloat: + return "float2"; default: TORCH_CHECK(false, "Undefined type ", scalar_type); return "Undefined"; diff --git a/aten/src/ATen/native/mps/operations/Im2Col.mm b/aten/src/ATen/native/mps/operations/Im2Col.mm new file mode 100644 index 00000000000000..5fd3f0d9ac36d5 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Im2Col.mm @@ -0,0 +1,191 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { +using namespace mps; +static MetalShaderLibrary lib(R"IM2COL_METAL( +// Heavily inspired by https://github.com/pytorch/pytorch/blob/09519eb19/aten/src/ATen/native/cuda/im2col.cuh#L51 +template +void im2col_kernel( + constant T * input, + device T * output, + uint2 kernel_size, + long2 input_offset, + long2 input_size, + long2 dilation, + ulong2 input_strides, + ulong output_stride) { + for (ulong i = 0; i < kernel_size.y; ++i) { + for (ulong j = 0; j < kernel_size.x; ++j) { + auto input_pos = input_offset + long2(j, i) * dilation; + if (input_pos.x < 0 || input_pos.y < 0 || input_pos.x >= input_size.x || input_pos.y >= input_size.y) { + *output = T(0); + } else { + auto offset = input_pos.x * input_strides.x + input_pos.y * input_strides.y; + *output = input[offset]; + } + output += output_stride; + } + } +} + +template +kernel void im2col( + constant T * inputData [[buffer(0)]], + device T * outputData [[buffer(1)]], + constant uint4 & kernel_dilation [[buffer(2)]], + constant int4 & padding_stride [[buffer(3)]], + constant ulong4 & input_strides [[buffer(4)]], + constant ulong4 & output_strides [[buffer(5)]], + constant long4 & input_sizes [[buffer(6)]], + uint3 thread_index [[thread_position_in_grid]]) { + // thread_index is (output_length, input_channels, input_batch) + const auto N = thread_index.z; + const auto C = thread_index.y; + const auto L = thread_index.x; + const auto output_width = output_strides.w; + const auto o_x = L % output_width; + const auto o_y = L / output_width; + auto i_x = o_x * padding_stride.z - padding_stride.x; + auto i_y = o_y * padding_stride.w - padding_stride.y; + ulong kernel_size = kernel_dilation.x * kernel_dilation.y; + outputData += N * output_strides.z + C * kernel_size * output_strides.y + L * output_strides.x; + inputData += N * input_strides.w + C * input_strides.z; + im2col_kernel(inputData, outputData, kernel_dilation.xy, long2(i_x, i_y), input_sizes.xy, long2(kernel_dilation.zw), input_strides.xy, output_strides.y); +} + +#define INSTANTIATE_IM2COL(DTYPE) \ +template \ +[[host_name("im2col_" #DTYPE)]] \ +kernel void im2col( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant uint4 & kernel_dilation [[buffer(2)]], \ + constant int4 & padding_stride [[buffer(3)]], \ + constant ulong4 & input_strides [[buffer(4)]], \ + constant ulong4 & output_strides [[buffer(5)]], \ + constant long4 & input_sizes [[buffer(6)]], \ + uint3 thread_index [[thread_position_in_grid]]) + +INSTANTIATE_IM2COL(bool); +INSTANTIATE_IM2COL(float); +INSTANTIATE_IM2COL(float2); +INSTANTIATE_IM2COL(half); +INSTANTIATE_IM2COL(half2); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_IM2COL(bfloat); +#endif +)IM2COL_METAL"); + +namespace { +static void im2col_out_mps_template(Tensor& output, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TORCH_CHECK(kernel_size.size() == 2, "It is expected kernel_size equals to 2, but got size ", kernel_size.size()); + + TORCH_CHECK(dilation.size() == 2, "It is expected dilation equals to 2, but got size ", dilation.size()); + + TORCH_CHECK(padding.size() == 2, "It is expected padding equals to 2, but got size ", padding.size()); + + TORCH_CHECK(stride.size() == 2, "It is expected stride equals to 2, but got size ", stride.size()); + + const auto kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + Tensor input = input_.contiguous(); + + bool batched_input = true; + + if (input.dim() == 3) { + batched_input = false; + input = input.unsqueeze(0); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t input_height = input.size(2); + int64_t input_width = input.size(3); + + int64_t output_height = + (input_height + 2 * pad_height - (dilation_height * (kernel_height - 1) + 1)) / stride_height + 1; + int64_t output_width = (input_width + 2 * pad_width - (dilation_width * (kernel_width - 1) + 1)) / stride_width + 1; + int64_t n_output_plane = n_input_plane * kernel_width * kernel_height; + int64_t output_length = output_height * output_width; + + output.resize_({batch_size, n_output_plane, output_length}); + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto im2colPSO = lib.getPipelineStateForFunc("im2col_" + mps::scalarToMetalTypeString(input)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + std::array kernel_dilation = {static_cast(kernel_width), + static_cast(kernel_height), + static_cast(dilation_width), + static_cast(dilation_height)}; + std::array padding_stride = {static_cast(pad_width), + static_cast(pad_height), + static_cast(stride_width), + static_cast(stride_height)}; + std::array input_sizes = {input_width, input_height, n_input_plane, batch_size}; + std::array input_strides = {input.stride(3), input.stride(2), input.stride(1), input.stride(0)}; + std::array output_strides = {output.stride(2), output.stride(1), output.stride(0), output_width}; + getMPSProfiler().beginProfileKernel(im2colPSO, "im2col", {input, output}); + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:im2colPSO]; + mtl_setBuffer(computeEncoder, input, 0); + mtl_setBuffer(computeEncoder, output, 1); + mtl_setBytes(computeEncoder, kernel_dilation, 2); + mtl_setBytes(computeEncoder, padding_stride, 3); + mtl_setBytes(computeEncoder, input_strides, 4); + mtl_setBytes(computeEncoder, output_strides, 5); + mtl_setBytes(computeEncoder, input_sizes, 6); + [computeEncoder dispatchThreads:MTLSizeMake(output_length, n_input_plane, batch_size) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + getMPSProfiler().endProfileKernel(im2colPSO); + } + }); + if (!batched_input) { + output = output.squeeze(0); + } +} + +} // anonymous namespace +Tensor& im2col_out_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + Tensor& output) { + im2col_out_mps_template(output, input, kernel_size, dilation, padding, stride); + return output; +} + +Tensor im2col_mps(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + im2col_out_mps_template(output, input, kernel_size, dilation, padding, stride); + return output; +} +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index e911281114802c..a13e660b9c857e 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -304,7 +304,6 @@ static Tensor nonzero_fallback(const Tensor& self) { if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && (self.numel() >= nonZeroMaxSize || self.is_complex())) { - https: // github.com/pytorch/pytorch/issues/122916 TORCH_WARN_ONCE("MPS: nonzero op is not natively supported for the provided input on MacOS14", "Falling back on CPU. This may have performance implications.", "See github.com/pytorch/pytorch/issues/122916 for further info"); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index b1158865c4e688..e40454307ac97f 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -5,6 +5,7 @@ #include #include // For MTLLanguageVersion_3_1 +#include #include #include @@ -509,11 +510,123 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L return output; } +static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + using namespace mps; + + id aBuffer = getMTLBufferStorage(batch1); + id bBuffer = getMTLBufferStorage(batch2); + id resBuffer = getMTLBufferStorage(result); + + MPSStream* mpsStream = getCurrentMPSStream(); + id device = MPSDevice::getInstance()->device(); + id computeEncoder = mpsStream->commandEncoder(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + mpsStream->endKernelCoalescing(); + + uint64_t originalBatchSize = batch1.sizes().size() > 2 ? batch1.size(0) : 1; + uint64_t aRows = batch1.size(-2); + uint64_t bRows = batch2.size(-2); + uint64_t resRows = result.size(-2); + uint64_t aCols = batch1.size(-1); + uint64_t bCols = batch2.size(-1); + uint64_t resCols = result.size(-1); + uint64_t aElemSize = batch1.element_size(); + uint64_t bElemSize = batch2.element_size(); + uint64_t resElemSize = result.element_size(); + MPSDataType dtype = getMPSDataType(batch1); + + uint64_t elemInMatrix = resRows * resCols; + uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix); + uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize); + uint64_t lastBatchSize = originalBatchSize % batchSize; + + id commandBuffer = mpsStream->commandBuffer(); + + auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2]; + + MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ]; + MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ]; + MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ]; + auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape]; + aDesc_.preferPackedRows = true; + auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape]; + bDesc_.preferPackedRows = true; + + auto resDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:resShape]; + resDesc_.preferPackedRows = true; + + getMPSProfiler().beginProfileKernel(matmul, " tiled_bmm_mps", {batch1, batch2}); + + // Descriptors to use for last batch if it exists + //.matrices is a readonly property so we need a separate descriptor. + MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_; + if (lastBatchSize != 0) { + aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype + shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]]; + aDescLastBatch_.preferPackedRows = true; + bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype + shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]]; + bDescLastBatch_.preferPackedRows = true; + resDescLastBatch_ = + [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]]; + resDescLastBatch_.preferPackedRows = true; + } + + uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize); + auto aDesc = aDesc_; + auto bDesc = bDesc_; + auto resDesc = resDesc_; + for (const auto i : c10::irange(requiredIterations)) { + if (i == requiredIterations - 1 && lastBatchSize != 0) { + aDesc = aDescLastBatch_; + bDesc = bDescLastBatch_; + resDesc = resDescLastBatch_; + } + const uint64_t aArrayOffset = i * batchSize * aRows * aCols; + const uint64_t bArrayOffset = i * batchSize * bRows * bCols; + const uint64_t resArrayOffset = i * batchSize * resRows * resCols; + + auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer + offset:(batch1.storage_offset() + aArrayOffset) * aElemSize + descriptor:aDesc] autorelease]; + auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer + offset:(batch2.storage_offset() + bArrayOffset) * bElemSize + descriptor:bDesc] autorelease]; + auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer + offset:(result.storage_offset() + resArrayOffset) * resElemSize + descriptor:resDesc] autorelease]; + + [matmul encodeToCommandEncoder:computeEncoder + commandBuffer:commandBuffer + sourceArrays:@[ aMatrix, bMatrix ] + destinationArray:resMatrix]; + } + } + }); + return result; + } else { + TORCH_CHECK(false, "Tiling of batch matmul for larger than 2**32 entries only available from MacOS15 onwards"); + } +} + static Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { using namespace mps; TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs"); + // Currently unsupported if the matmul output goes over the 32-bit indexing limit + TORCH_CHECK( + batch1.size(1) * batch2.size(2) <= pow(2, 32), + "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ", + batch1.size(1), + ",", + batch2.size(2), + ", needs to be less than 2**32 elements.", + "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues"); + if (batch1.numel() == 0 || batch2.numel() == 0) { result.zero_(); return result; @@ -543,6 +656,13 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L } } + // Check if we need to split the batch to do the computation + uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2); + if (resultSize > pow(2, 32)) { + result = tiled_bmm_out_mps_impl(batch1, batch2, result); + return result; + } + MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public mps::MPSCachedGraph { diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index 271b4b2c492b34..c45054f8bc1638 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -437,6 +437,20 @@ static void nllnd_loss_forward_impl(Tensor& output, if (output.numel() == 0) return; + // https://github.com/pytorch/pytorch/blob/042f2f7746a064f1527d95d1f1d712b4f0b34186/aten/src/ATen/native/cuda/Loss.cu#L335-L346 + if (target_arg.numel() == 0) { + // Here target (and input) have zero elements + // Mean reduction on empty tensors produces NaN. See the discussion in + // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if (reduction == Reduction::Mean) { + output.fill_(std::numeric_limits::quiet_NaN()); + } else { + output.zero_(); + } + total_weight.zero_(); + return; + } + struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; @@ -537,7 +551,9 @@ static void nllnd_loss_forward_impl(Tensor& output, mpsGraphBatchSizeTensor = [mpsGraph reductionSumWithTensor:mpsSelectOneTensor axes:nil name:@"batchSizeReductionTensor"]; - mpsGraphReducedTensor = divisionNoNaN(mpsGraph, mpsGraphReducedTensor, mpsGraphBatchSizeTensor); + mpsGraphReducedTensor = [mpsGraph divisionWithPrimaryTensor:mpsGraphReducedTensor + secondaryTensor:mpsGraphBatchSizeTensor + name:@"divisionTensor"]; } } diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 72337b8d340d7d..391422e77b5350 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -296,7 +296,9 @@ Check if running mean exists (maybe do this check before making graph) newCachedGraph->runningVarInplaceUpdate_ = runningVarInplaceUpdate; }); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape); + const auto needs_gather = memory_format != MemoryFormat::ChannelsLast; + auto inputPlaceholder = + Placeholder(cachedGraph->inputTensor_, self, input_shape, needs_gather, MPSDataTypeInvalid, needs_gather); auto weightPlaceholder = Placeholder(); if (has_weight) weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape); @@ -319,7 +321,8 @@ Check if running mean exists (maybe do this check before making graph) runningVarInplaceUpdatePlaceholder = Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value()); } - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, input_shape, false); + auto outputPlaceholder = + Placeholder(cachedGraph->outputTensor_, output, input_shape, false, MPSDataTypeInvalid, needs_gather); auto saveMeanPlaceholder = Placeholder(cachedGraph->saveMeanTensor_, save_mean); auto saveVarPlaceholder = Placeholder(cachedGraph->saveVarTensor_, save_var); @@ -714,7 +717,11 @@ static string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor secondaryTensor:epsilonTensor name:nil]; +#ifdef __MAC_15_0 + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; +#else rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; +#endif MPSGraphTensor* bnForwardTensor = [mpsGraph multiplicationWithPrimaryTensor:xMinusMean secondaryTensor:rsqrtTensor name:nil]; @@ -739,7 +746,11 @@ static string get_mem_string(c10::MemoryFormat memory_format) { MPSGraphTensor* varianceEpsTensor = [mpsGraph additionWithPrimaryTensor:runningVarTensor secondaryTensor:epsilonTensor name:nil]; +#ifdef __MAC_15_0 + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; +#else rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; +#endif } gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor secondaryTensor:rsqrtTensor name:nil]; @@ -901,8 +912,7 @@ static string get_mem_string(c10::MemoryFormat memory_format) { for (const auto idx : c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); } - for (const auto idx : c10::irange(axis, input.dim())) { - (void)idx; // Suppress unused variable + for (C10_UNUSED auto idx : c10::irange(axis, input.dim())) { stat_shape.push_back(1); } mean = mean.view(stat_shape); diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 6d372f96b7f90f..cd0e75d84dc3bc 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -153,13 +153,16 @@ static void reduction_out_mps(const Tensor& input_t, const std::string& func_name) { bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS); MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name); + // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor bool canSqueezeLastDim = true; IntArrayRef input_shape = input_t.sizes(); if (opt_dim.has_value()) { IntArrayRef dim = opt_dim.value(); for (const auto dim_val : dim) { auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); - if (wrap_dim >= 4) { + // canSqueeze logic is broken when dim is negative, it introduces off-by-one-erros or crashes + // See https://github.com/pytorch/pytorch/issues/136132#issuecomment-2354482608 + if (wrap_dim >= 4 || dim_val < 0) { canSqueezeLastDim = false; } TORCH_CHECK( diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index ea94faef98bdfe..334a056cddfb53 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -225,7 +225,11 @@ static void unary_op(const Tensor& self, CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sqrt_out_mps, squareRoot) +#ifdef __MAC_15_0 +CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reciprocalSquareRoot) +#else CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reverseSquareRoot) +#endif CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(neg_out_mps, negative) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log_out_mps, logarithm) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log10_out_mps, logarithmBase10) diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index f1bc8ebf97026c..ba23956b5d32a4 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -729,27 +729,6 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) { return kernelName + "_kernel_" + std::to_string(dim == 0 ? 1 : dim); } -static const std::string& getGatherScatterScalarType(const Tensor& t) { - auto scalar_type = t.scalar_type(); - static std::unordered_map scalarToMetalType = { - {c10::ScalarType::Float, "float"}, - {c10::ScalarType::Half, "half"}, - {c10::ScalarType::BFloat16, "bfloat"}, - {c10::ScalarType::Long, "long"}, - {c10::ScalarType::Int, "int"}, - {c10::ScalarType::Short, "short"}, - {c10::ScalarType::Char, "char"}, - {c10::ScalarType::Byte, "uchar"}, - {c10::ScalarType::Bool, "bool"}, - {c10::ScalarType::ComplexFloat, "float2"}, - {c10::ScalarType::ComplexHalf, "half2"}, - }; - - auto it = scalarToMetalType.find(scalar_type); - TORCH_CHECK(it != scalarToMetalType.end(), "Unsupported type byte size: ", scalar_type); - return it->second; -} - static std::string genScatterGatherCvtFunc(const std::string& dtypeSrc, const std::string& dtypeDst, bool needsConj) { const bool srcComplex = dtypeSrc[dtypeSrc.size() - 1] == '2'; const bool dstComplex = dtypeDst[dtypeDst.size() - 1] == '2'; @@ -805,8 +784,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { id computeEncoder = mpsStream->commandEncoder(); std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/false); id gatherPSO = getPipelineState(functionName, - getGatherScatterScalarType(src), - getGatherScatterScalarType(output), + scalarToMetalTypeString(src), + scalarToMetalTypeString(output), /*needsScatter=*/false, src.is_conj() != dst.is_conj()); @@ -862,8 +841,8 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst) { std::string functionName = getGatherScatterFunctionName(output.scalar_type(), output.dim(), /*needsScatter=*/true); id scatterPSO = getPipelineState(functionName, - getGatherScatterScalarType(src), - getGatherScatterScalarType(output), + scalarToMetalTypeString(src), + scalarToMetalTypeString(output), /*needsScatter=*/true, src.is_conj() != output.is_conj()); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b8168d1af4347b..83d04c4a14c958 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -537,9 +537,11 @@ - func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor tags: core + autogen: avg_pool1d.out - func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor tags: core + autogen: adaptive_avg_pool1d.out # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) @@ -3287,7 +3289,9 @@ autogen: native_layer_norm_backward.out tags: core -- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor +- func: rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + dispatch: + CompositeImplicitAutograd: rms_norm_symint - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method @@ -3398,9 +3402,9 @@ - func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor -- func: wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor +- func: _wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor -- func: wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor +- func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor - func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor @@ -3920,11 +3924,10 @@ tags: core # For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. -# FIXME: fix CI jobs and re-enable this -#- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) -# device_check: NoCheck # TensorIterator -# dispatch: -# CompositeExplicitAutograd: mean_dtype_out +- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + dispatch: + CompositeExplicitAutograd: mean_dtype_out - func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor structured_delegate: mean.out @@ -8553,7 +8556,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __rshift__ + CPU, CUDA, MPS: __rshift__ tags: pointwise - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) @@ -13042,12 +13045,14 @@ dispatch: CPU: im2col_out_cpu CUDA: im2col_out_cuda + MPS: im2col_out_mps - func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: CPU: im2col_cpu CUDA: im2col_cuda + MPS: im2col_mps - func: isfinite(Tensor self) -> Tensor variants: function, method @@ -14705,6 +14710,11 @@ CUDA: _fbgemm_dense_to_jagged_forward_symint CPU: _padded_dense_to_jagged_forward_cpu +- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + variants: function + device_check: NoCheck + dispatch: {} + - func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor dispatch: NestedTensorCPU: NestedTensor_softmax_dropout @@ -14764,7 +14774,7 @@ CPU: _scaled_dot_product_flash_attention_cpu tags: nondeterministic_seeded -- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) dispatch: CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable tags: nondeterministic_seeded diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 16ed9195c08884..3cd2b6836d0668 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -581,7 +581,7 @@ inline std::tuple> check_shape_and_partition_( StackArray jagged_dims_tensor; const int num_jagged_dim = dense_tensor.dim() - 2; - TORCH_CHECK(num_jagged_dim <= kStackArrayMaxDims); + TORCH_CHECK(num_jagged_dim <= static_cast(kStackArrayMaxDims)); jagged_dims_tensor.ndim = num_jagged_dim; std::memcpy( &(jagged_dims_tensor.vals[0]), @@ -1220,7 +1220,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( // MI100 has independent shared mem and L1 int used_shared_kb = shared_kb; #endif - int used_shared_bytes = used_shared_kb << 10; + auto used_shared_bytes = static_cast(used_shared_kb << 10); AT_DISPATCH_INDEX_TYPES( x_offsets[0].scalar_type(), "check_shared_memory", [&] { auto B = y_0_reshaped.size(0); @@ -1355,7 +1355,7 @@ void jagged_dense_elementwise_jagged_output_opt_( int used_shared_bytes = calc_used_shared_bytes(y_reshaped.get_device()); set_max_dynamic_shared_mem_size_for_opt_search_kernel(used_shared_bytes); C10_CUDA_KERNEL_LAUNCH_CHECK(); - TORCH_CHECK(dynamic_smem_size <= used_shared_bytes); + TORCH_CHECK(dynamic_smem_size <= static_cast(used_shared_bytes)); } dim3 threads_bs = dim3(1024, 1, 1); dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); diff --git a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp index 5fcdd22c034a5f..33334bef8f4346 100644 --- a/aten/src/ATen/native/quantized/cpu/TensorShape.cpp +++ b/aten/src/ATen/native/quantized/cpu/TensorShape.cpp @@ -163,10 +163,11 @@ Tensor cat_quantized_cpu(const ITensorListRef& qxs, int64_t dim) { TORCH_CHECK(is_valid_quantization_scheme(materialized[0]), "Only per-tensor quantization is supported in 'cat'!"); - if (all_inputs_sharing_qparams(materialized)) { + if (!all_inputs_sharing_qparams(materialized)) { // TODO: if possible change this warning to an error T194501002 TORCH_WARN("All inputs of this cat operator must share the same quantization parameters. Otherwise large numerical inaccuracies may occur."); - } check_cat_no_zero_dim(materialized); + } + check_cat_no_zero_dim(materialized); dim = legacy_cat_wrap_dim(dim, materialized); double _scale = materialized[0].get().q_scale(); int64_t _zero_point = materialized[0].get().q_zero_point(); diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h index 85451fb57482a3..9f2dfd26118acf 100644 --- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h +++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h @@ -313,9 +313,9 @@ c10::intrusive_ptr> deserialize_conv( output_padding.emplace_back(config_vals.at(idx)); idx++; } - int64_t groups = config_vals.at(idx); + int64_t groups [[maybe_unused]] = config_vals.at(idx); idx++; - int64_t flags = config_vals.at(idx); + int64_t flags [[maybe_unused]] = config_vals.at(idx); idx++; TORCH_INTERNAL_ASSERT(idx == static_cast(config_vals.size()), "Unexpected length of config_vals, expected ", @@ -323,7 +323,7 @@ c10::intrusive_ptr> deserialize_conv( " got ", config_vals.size()); - bool transpose = flags & (1 << 0); + bool transpose [[maybe_unused]] = flags & (1 << 0); int64_t other_flags = flags & ~(1 << 0); TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, "."); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 28fcab13ff3e06..f4c55b2a3cfe4f 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -436,12 +436,12 @@ at::Tensor wrapped_quantized_linear_meta( #endif // USE_FBGEMM } -at::Tensor wrapped_linear_prepack(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack(const at::Tensor& weight, const at::Tensor& weight_scale, const at::Tensor& weight_zero_point, const at::Tensor& bias); -at::Tensor wrapped_linear_prepack(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack(const at::Tensor& weight, const at::Tensor& weight_scale, const at::Tensor& weight_zero_point, const at::Tensor& bias) { @@ -474,14 +474,14 @@ at::Tensor wrapped_linear_prepack(const at::Tensor& weight, #endif // USE_FBGEMM } -at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, const at::Tensor& input_zero_point, const at::Tensor& packed_weight, const at::Tensor& output_scale, const at::Tensor& output_zero_point, [[maybe_unused]] const int64_t out_channel); -at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor& input, const at::Tensor& input_scale, const at::Tensor& input_zero_point, const at::Tensor& packed_weight, const at::Tensor& output_scale, @@ -507,12 +507,12 @@ at::Tensor wrapped_quantized_linear_prepacked(const at::Tensor& input, const at: #endif // USE_FBGEMM } -at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight, [[maybe_unused]] const at::Tensor& weight_scale, [[maybe_unused]] const at::Tensor& weight_zero_point, [[maybe_unused]] const at::Tensor& bias); -at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, +at::Tensor _wrapped_linear_prepack_meta(const at::Tensor& weight, [[maybe_unused]] const at::Tensor& weight_scale, [[maybe_unused]] const at::Tensor& weight_zero_point, [[maybe_unused]] const at::Tensor& bias) { @@ -530,7 +530,7 @@ at::Tensor wrapped_linear_prepack_meta(const at::Tensor& weight, #endif // USE_FBGEMM } -at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, +at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& input_scale, [[maybe_unused]] const at::Tensor& input_zero_point, [[maybe_unused]] const at::Tensor& packed_weight, @@ -538,7 +538,7 @@ at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& output_zero_point, const int64_t out_channel); -at::Tensor wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, +at::Tensor _wrapped_quantized_linear_prepacked_meta(const at::Tensor& input, [[maybe_unused]] const at::Tensor& input_scale, [[maybe_unused]] const at::Tensor& input_zero_point, [[maybe_unused]] const at::Tensor& packed_weight, @@ -695,21 +695,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("_quantized::linear_prepack_fp16_legacy"), TORCH_FN(QLinearPackWeightFp16Legacy::run)); m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear)); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"), - wrapped_linear_prepack); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"), + _wrapped_linear_prepack); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"), - wrapped_quantized_linear_prepacked); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"), + _wrapped_quantized_linear_prepacked); } TORCH_LIBRARY_IMPL(_quantized, Meta, m) { m.impl(TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear"), TORCH_FN(wrapped_quantized_linear_meta)); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_linear_prepack"), - wrapped_linear_prepack_meta); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_linear_prepack"), + _wrapped_linear_prepack_meta); m.impl( - TORCH_SELECTIVE_NAME("_quantized::wrapped_quantized_linear_prepacked"), - wrapped_quantized_linear_prepacked_meta); + TORCH_SELECTIVE_NAME("_quantized::_wrapped_quantized_linear_prepacked"), + _wrapped_quantized_linear_prepacked_meta); } TORCH_LIBRARY_IMPL(onednn, CPU, m) { diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index be2dc9eb44f485..e3b95687306e8b 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -251,8 +251,8 @@ TORCH_LIBRARY(_quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_linear_prepack(Tensor W, Tensor W_scale, Tensor W_zero_point, Tensor B) -> Tensor"), {at::Tag::flexible_layout}); + m.def(TORCH_SELECTIVE_SCHEMA("_quantized::_wrapped_quantized_linear_prepacked(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor W_prepack, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor Y"), {at::Tag::flexible_layout}); } TORCH_LIBRARY(onednn, m) { diff --git a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp index 947332635203ba..90d3e9cce6734f 100644 --- a/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp +++ b/aten/src/ATen/native/sparse/FlattenIndicesKernel.cpp @@ -27,5 +27,6 @@ REGISTER_AVX512_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_AVX2_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_VSX_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); +REGISTER_SVE256_DISPATCH(flatten_indices_stub, &flatten_indices_cpu_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 6f20b4e245c7a3..e86d5c46a795fa 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -161,16 +161,19 @@ REGISTER_AVX512_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_ REGISTER_AVX2_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_VSX_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(mul_sparse_sparse_out_stub, &mul_sparse_sparse_out_cpu_kernel); REGISTER_ARCH_DISPATCH(sparse_mask_intersection_out_stub, DEFAULT, &sparse_mask_intersection_out_cpu_kernel); REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel); REGISTER_ARCH_DISPATCH(sparse_mask_projection_out_stub, DEFAULT, &sparse_mask_projection_out_cpu_kernel); REGISTER_AVX512_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_AVX2_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_VSX_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); REGISTER_ZVECTOR_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); +REGISTER_SVE256_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_cpu_kernel); } diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index c31721729036ff..3544a3cf0b16c6 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -46,6 +46,15 @@ desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. +- tag: flexible_layout + desc: | + This tag indicates that the custom operator can accept inputs with varying + strides/storage_offset and that when compiled, Inductor is allowed to change + the strides/storage_offset of inputs to the custom operator. + Only one of {needs_fixed_stride_order, flexible_layout} can apply; if + multiple are assigned then we assume the most restrictive one. # NOTE [Core ATen Ops] - tag: core diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 5369e87d58bec3..d91955412fc588 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -449,6 +449,7 @@ REGISTER_AVX2_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); +REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp); int64_t _fused_sdp_choice_meta( const Tensor& query_, @@ -536,6 +537,24 @@ std::optional convert_boolean_attn_mask(const std::optional& att // Otherwise, attn_mask represents an additive attention tensor return attn_mask; } + +// alternate version to workaround -inf issue with cuDNN +// TODO(eqy): delete this when cuDNN -inf issue is resolved +std::optional convert_boolean_attn_mask_cudnn(const std::optional& attn_mask, caffe2::TypeMeta dtype) { + // Pass through + if(!attn_mask.has_value()){ + return std::nullopt; + } + // Convert boolean mask to additive mask; need to invert mask to indicate what + // to mask *out*. + if (attn_mask->dtype() == at::kBool) { + // TODO Use the max type of the input and output + return at::where(attn_mask->logical_not(), -65504.0, at::scalar_tensor(0.0, at::TensorOptions().dtype(dtype))); + } + // Otherwise, attn_mask represents an additive attention tensor + return attn_mask; +} + // Memory Efficient Attention requires a padded attn mask bias // This function pads the attn_mask bias to be a multiple of 16 // Then slices the padded bias to the original size @@ -698,15 +717,16 @@ Tensor scaled_dot_product_attention( query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa); } sdp::SDPBackend backend = static_cast(choice_int); - std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); switch (backend) { case sdp::SDPBackend::cudnn_attention: { + std::optional attn_mask = convert_boolean_attn_mask_cudnn(attn_mask_, query_.dtype()); bool compute_logsumexp = should_compute_logsumexp(query_, key, value); auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention( - query_, key, value, attn_mask_, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); + query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } case sdp::SDPBackend::flash_attention: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); if(query_.device().type() == DeviceType::CUDA){ c10::SymInt og_size = query_.sym_size(-1); Tensor query_padded = pad_last_dim<8, false>(query_); @@ -723,6 +743,7 @@ Tensor scaled_dot_product_attention( query_, key, value, dropout_p, is_causal, attn_mask, scale)); } case sdp::SDPBackend::efficient_attention: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); bool compute_logsumexp = should_compute_logsumexp(query_, key, value); if (attn_mask.has_value()) { attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);; @@ -732,12 +753,15 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_and_lse); } case sdp::SDPBackend::overrideable: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable( query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } - case sdp::SDPBackend::math: - if (query_.device().type() == DeviceType::MPS && dropout_p == 0.0 + case sdp::SDPBackend::math: { + std::optional attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); + if ((!GradMode::is_enabled() || (!query_.requires_grad() && !key.requires_grad() && !value.requires_grad())) + && query_.device().type() == DeviceType::MPS && dropout_p == 0.0 && query_.is_contiguous() && key.is_contiguous() && value.is_contiguous() && !query_.is_nested() && !key.is_nested() && !value.is_nested()) { return std::get<0>(at::_scaled_dot_product_attention_math_for_mps( @@ -760,6 +784,7 @@ Tensor scaled_dot_product_attention( std::nullopt, /*dropout_mask*/ scale, enable_gqa)); + } default: TORCH_CHECK( false, @@ -869,7 +894,6 @@ _scaled_dot_product_flash_attention_cpu( int64_t batchSize = query.size(0); int64_t qSize = query.size(2); int64_t num_head = query.size(1); - int64_t headSize = query.size(3); TORCH_CHECK(c10::isFloatingType(dtype), "scaled_dot_product_attention_flash_attention: Expected data type in FP32, FP64, BF16, FP16, but got ", dtype, " instead."); @@ -887,7 +911,7 @@ _scaled_dot_product_flash_attention_cpu( (attn_mask.value().dim() == 2 || attn_mask.value().dim() == 4), "scaled_dot_product_attention_flash_attention: Attention mask dim in {2, 4}"); - at::Tensor output = at::empty({batchSize, qSize, num_head, headSize}, query.options()); + at::Tensor output = at::empty_like(query, query.options()).transpose(1, 2); const auto accumulate_dtype = toOpMathType(dtype); at::Tensor logsumexp = at::empty({batchSize, qSize, num_head}, query.options().dtype(accumulate_dtype)); diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 78566555865c6d..914915da01a03d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -774,6 +774,18 @@ std::tuple */, log_sumexp/*Tensor softmaxstats*/, attention/*Tensor o*/, cudnn_seed/*Tensor dropoutseed*/, cudnn_offset/*Tensor dropoutoffset*/); // TODO(eqy): support debug_attn_mask - return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor()); + return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } std::tuple _scaled_dot_product_efficient_attention_cuda( @@ -1102,10 +1115,13 @@ std::tuple _efficient_ offset_t = at::empty({}, at::dtype(at::kLong).device(device)); } else { auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor( - at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor( - at::Scalar(static_cast(offset)), at::dtype(at::kLong)); +#ifdef USE_ROCM + const auto options = at::dtype(at::kLong).device(at::kCUDA); +#else + const auto options = at::dtype(at::kLong); +#endif + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), options); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), options); } } else { // Not using dropout @@ -1118,7 +1134,8 @@ std::tuple _efficient_ auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } // AOTriton may accept aligned on logsumexp tensor in the future for better @@ -1147,8 +1164,16 @@ std::tuple _efficient_ using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16); at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options()); + const bool use_philox_state = in_capture_stream; + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); hipError_t err; // TODO: Error handling err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), @@ -1158,8 +1183,11 @@ std::tuple _efficient_ mk_aotensor<2>(softmax_lse, "M"), mk_aotensor(output_t, "Out"), dropout_p, - use_dropout ? *seed_t.data_ptr() : 0, - use_dropout ? *offset_t.data_ptr() : 0, + seed, + offset1, + offset2, + seed_output, + offset_output, mk_aotensor(softmax_fa_t, "encoded_softmax"), is_causal, stream); diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 33b95945988b49..07a398573665ac 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -195,6 +195,27 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ const int64_t num_heads = query.size(1); const int64_t head_dim_qk = query.size(3); const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_batch_q = query.size(2); + const int64_t max_seqlen_batch_k = key.size(2); + + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + } + } + const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); auto dq = at::empty_like(query); auto dk = at::empty_like(key); @@ -211,6 +232,7 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ query /*const Tensor& q*/, key /*const Tensor& k*/, value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, out /*const Tensor& o*/, grad_out/*const Tensor& dO*/, logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, @@ -219,7 +241,7 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ dv/*Tensor& dV*/, philox_seed/*Tensor& dropoutseed*/, philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(dq, dk, dv); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } std::tuple @@ -394,7 +416,8 @@ _efficient_attention_backward( auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "[AOTriton] Accelerated SDPA only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); bool is_causal; @@ -419,6 +442,7 @@ _efficient_attention_backward( hipError_t err; using aotriton::v2::flash::attn_bwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -435,8 +459,9 @@ _efficient_attention_backward( mk_aotensor<2>(softmax_lse, "L"), mk_aotensor<2>(delta, "delta"), float(dropout_p), - rng_engine_inputs.seed_.val, - rng_engine_inputs.offset_.val, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, is_causal, stream); #else diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 5580194a2aa80b..d84d9417692166 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -25,7 +25,10 @@ #include #if USE_ROCM +#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) #include +#define USE_AOTRITON 1 +#endif #endif /** @@ -207,7 +210,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -217,8 +222,19 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } + c10::string_view arch(dprops->gcnArchName); + if (arch == "gfx1100") { + static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_navi3x) { + TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } +#else + return false; +#endif #else - auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -238,7 +254,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; + auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -248,8 +266,19 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } + c10::string_view arch(dprops->gcnArchName); + if (arch == "gfx1100") { + static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; + if (!enable_navi3x) { + TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." + " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); + return false; + } + } +#else + return false; +#endif #else - auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -550,7 +579,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { check_cudnn_deterministic, // check_is_causal, check_dtypes_low_precision, - check_for_attn_mask_cudnn, + check_attn_mask_shape, check_cudnn_hardware_support ); for (auto& constraint : general_constraints) { @@ -605,9 +634,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { } } } +#if USE_ROCM + constexpr bool backend_supports_grouped_query_attention = false; +#else + constexpr bool backend_supports_grouped_query_attention = true; +#endif if (has_only_dense_inputs(params)) { constexpr auto dense_constraints = array_of( - check_batch_size_and_num_heads_dense, + check_batch_size_and_num_heads_dense, check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense); for (auto& constraint : dense_constraints) { @@ -641,7 +675,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { check_all_tensors_on_device, check_mem_efficient_hardware_support, check_tensor_shapes, - check_head_dim_size_mem_efficient); +#ifdef USE_ROCM + check_head_dim_size_flash +#else + check_head_dim_size_mem_efficient +#endif + ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index 1c238c751a05c9..57d5c34444390d 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -115,6 +115,18 @@ aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view ten cast_dtype(q.dtype())); } +inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q) +{ + return aotriton::TensorView<0>(reinterpret_cast(q.data_ptr()), + cast_dtype(q.dtype())); +} + +inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr) +{ + return aotriton::TensorView<0>(reinterpret_cast(ptr), + aotriton::DType::kUInt64); // AOTriton excepts unsigned int64 +} + } // namespace aotriton_adapter } // namespace sdp diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 6d363d5efa0752..92e51f85d8e540 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -72,7 +72,8 @@ void check_gpu_arch(hipStream_t stream) { auto ret = aotriton::v2::flash::check_gpu(stream); if (hipSuccess != ret) { TORCH_CHECK(false, - "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx942:sramecc+:xnack-)") + "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs" + " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)") } } @@ -164,6 +165,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); at::Tensor seed_t, offset_t; + at::PhiloxCudaState philox_state; + bool use_philox_state = false; if (p_dropout > 0.0) { // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -171,12 +174,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head int64_t counter_offset = batch_size * num_heads * 32; // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset); + philox_state = gen->philox_cuda_state(counter_offset); if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong)); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); } else { + // See Note [CUDA Graph-safe RNG states] about the design + use_philox_state = true; seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } @@ -185,19 +190,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } else { - seed_t = at::empty({}, at::dtype(at::kLong)); - offset_t = at::empty({}, at::dtype(at::kLong)); - } - } - - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } } @@ -219,9 +213,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; + using aotriton::TensorView; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); err = attn_fwd(mk_aotensor(q_t, "q"), mk_aotensor(k_t, "k"), mk_aotensor(v_t, "v"), @@ -230,8 +232,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head mk_aotensor<2>(M, "M"), mk_aotensor(output_t, "Out"), p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, + seed, + offset1, + offset2, + seed_output, + offset_output, mk_aotensor(softmax_fa_t, "encoded_softmax"), is_causal, stream); @@ -392,17 +397,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv_expanded = dv; } - at::PhiloxCudaState philox_args; - if (p_dropout > 0.0) { - if (at::cuda::currentStreamCaptureStatus() == - at::cuda::CaptureStatus::None) - { - philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); - } else { // dropout + capture - philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); - } - } - at::Tensor q_t = q.permute({0,2,1,3}); at::Tensor k_t = k.permute({0,2,1,3}); at::Tensor v_t = v.permute({0,2,1,3}); @@ -420,6 +414,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si { using aotriton::v2::flash::attn_bwd; using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); err = attn_bwd(mk_aotensor(q_t, "q"), @@ -436,8 +431,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si mk_aotensor<2>(softmax_lse_cont, "L"), mk_aotensor<2>(delta, "delta"), p_dropout, - philox_args.seed_.val, - philox_args.offset_.val, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, is_causal, stream); } diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index b22e75301e8772..272700392e1d71 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -275,17 +275,6 @@ inline bool check_for_attn_mask(sdp_params const& params, bool debug) { return true; } -// TODO(eqy): remove this once support is added -inline bool check_for_attn_mask_cudnn(sdp_params const& params, bool debug) { - if (params.attn_mask.has_value()) { - if (debug) { - TORCH_WARN("cuDNN Attention does not support non-null attn_mask."); - } - return false; - } - return true; -} - inline bool check_attn_mask_shape(sdp_params const& params, bool debug) { auto attn_mask = params.attn_mask; if (!attn_mask.has_value()) { diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp index 114fad05afaa51..93a631c31e3acd 100644 --- a/aten/src/ATen/native/vulkan/api/Pipeline.cpp +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -230,7 +230,7 @@ void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept { rhs.handle_ = tmp_handle; } -bool operator==( +static bool operator==( const ComputePipeline::Descriptor& _1, const ComputePipeline::Descriptor& _2) { return ( diff --git a/aten/src/ATen/native/vulkan/api/QueryPool.cpp b/aten/src/ATen/native/vulkan/api/QueryPool.cpp index e74b7431b0f795..9c0c7fb2ea860c 100644 --- a/aten/src/ATen/native/vulkan/api/QueryPool.cpp +++ b/aten/src/ATen/native/vulkan/api/QueryPool.cpp @@ -170,13 +170,7 @@ void QueryPool::extract_results() { results_pending_ = false; } -std::ostream& operator<<(std::ostream& os, const VkExtent3D& extents) { - os << "{" << extents.width << ", " << extents.height << ", " << extents.depth - << "}"; - return os; -} - -std::string stringize(const VkExtent3D& extents) { +static std::string stringize(const VkExtent3D& extents) { std::stringstream ss; ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth << "}"; diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index ebed09d47e7d4e..4981f54accfd3e 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -253,7 +253,7 @@ BufferMemoryBarrier::BufferMemoryBarrier( // ImageSampler // -bool operator==( +static bool operator==( const ImageSampler::Properties& _1, const ImageSampler::Properties& _2) { return ( diff --git a/aten/src/ATen/native/vulkan/api/Tensor.cpp b/aten/src/ATen/native/vulkan/api/Tensor.cpp index f89cd0fb9aaca0..bbaaef1d1ddb8f 100644 --- a/aten/src/ATen/native/vulkan/api/Tensor.cpp +++ b/aten/src/ATen/native/vulkan/api/Tensor.cpp @@ -460,7 +460,7 @@ void vTensor::virtual_resize(const std::vector& new_sizes) { // vTensorStorage // -api::VulkanImage allocate_image( +static api::VulkanImage allocate_image( api::Context* const context_ptr, api::utils::uvec3& extents, const api::StorageType storage_type, @@ -505,7 +505,7 @@ api::VulkanImage allocate_image( /*allocate_memory = */ allocate_memory); } -api::VulkanBuffer allocate_buffer( +static api::VulkanBuffer allocate_buffer( api::Context* const context_ptr, const int64_t numel, const api::StorageType storage_type, diff --git a/aten/src/ATen/native/vulkan/impl/Packing.cpp b/aten/src/ATen/native/vulkan/impl/Packing.cpp index 3c47208c4805fc..10d918ffbbfd20 100644 --- a/aten/src/ATen/native/vulkan/impl/Packing.cpp +++ b/aten/src/ATen/native/vulkan/impl/Packing.cpp @@ -296,7 +296,7 @@ bool record_buffer_to_nchw_op( v_src.buffer_metadata()); } -vTensor channel_image_repacking( +static vTensor channel_image_repacking( const vTensor& v_input, api::GPUMemoryLayout target_layout, const api::ShaderInfo& shader_descriptor) { diff --git a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp index e12e69c4ebec2a..8fbcc54bd47400 100644 --- a/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Batchnorm.cpp @@ -15,7 +15,7 @@ struct Params final { float eps; }; -void record_op( +static void record_op( api::Context* const context, vTensor& v_output, const vTensor& v_input, diff --git a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp index e1445f40ac5f87..bb4e5187a4c8e5 100644 --- a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp +++ b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp @@ -1,9 +1,9 @@ +#ifdef USE_VULKAN_API #include #include #include #include #include -#include namespace at { namespace native { @@ -12,7 +12,7 @@ namespace ops { using namespace api::utils; -Tensor binary_op_scalar( +static Tensor binary_op_scalar( const Tensor& self_arg, const Scalar& other, const std::optional& alpha_arg, @@ -66,7 +66,7 @@ Tensor binary_op_scalar( return convert(v_output); } -Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { +static Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { // Similar to binary_op_scalar where tensors is mapped to float, we // also map known integer types (but not quant types) tensor to float. @@ -99,7 +99,7 @@ Tensor binary_op_preprocess_other_arg(const Tensor& other_arg) { return other; } -Tensor& binary_op_scalar_( +static Tensor& binary_op_scalar_( Tensor& self_arg, const Scalar& other, const std::optional& alpha_arg, @@ -149,7 +149,7 @@ Tensor& binary_op_scalar_( return self_arg; } -Tensor binary_op_tensor( +static Tensor binary_op_tensor( const Tensor& self_arg, const Tensor& other_arg, const std::optional& alpha_arg, @@ -222,7 +222,7 @@ Tensor binary_op_tensor( return convert(v_output); } -Tensor quantized_binary_op_tensor( +static Tensor quantized_binary_op_tensor( const Tensor& self_arg, const Tensor& other_arg, const double scale, @@ -310,7 +310,7 @@ Tensor quantized_binary_op_tensor( return convert_quantized(v_output); } -Tensor& binary_op_tensor_( +static Tensor& binary_op_tensor_( Tensor& self_arg, const Tensor& other_arg, const std::optional& alpha_arg, @@ -384,7 +384,7 @@ Tensor& binary_op_tensor_( return self_arg; } -Tensor add_scalar( +static Tensor add_scalar( const Tensor& self_arg, const Scalar& other, const Scalar& alpha) { @@ -392,7 +392,10 @@ Tensor add_scalar( self_arg, other, std::optional(alpha), VK_KERNEL(add_scalar)); } -Tensor& add_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { +static Tensor& add_scalar_( + Tensor& self, + const Scalar& other, + const Scalar& alpha) { return binary_op_scalar_( self, other, std::optional(alpha), VK_KERNEL(add_scalar_inplace)); } @@ -433,7 +436,7 @@ Tensor quantized_div( self_arg, other_arg, scale, zero_point, VK_KERNEL(quantized_div)); } -Tensor add_tensor( +static Tensor add_tensor( const Tensor& self_arg, const Tensor& other_arg, const Scalar& alpha) { @@ -441,7 +444,7 @@ Tensor add_tensor( self_arg, other_arg, std::optional(alpha), VK_KERNEL(add)); } -Tensor& add_tensor_( +static Tensor& add_tensor_( Tensor& self, const Tensor& other_arg, const Scalar& alpha) { @@ -449,7 +452,7 @@ Tensor& add_tensor_( self, other_arg, std::optional(alpha), VK_KERNEL(add_inplace)); } -Tensor sub_scalar( +static Tensor sub_scalar( const Tensor& self_arg, const Scalar& other, const Scalar& alpha) { @@ -460,7 +463,10 @@ Tensor sub_scalar( VK_KERNEL(add_scalar)); } -Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { +static Tensor& sub_scalar_( + Tensor& self, + const Scalar& other, + const Scalar& alpha) { return binary_op_scalar_( self, other, @@ -468,7 +474,7 @@ Tensor& sub_scalar_(Tensor& self, const Scalar& other, const Scalar& alpha) { VK_KERNEL(add_scalar_inplace)); } -Tensor sub_tensor( +static Tensor sub_tensor( const Tensor& self_arg, const Tensor& other_arg, const Scalar& alpha) { @@ -476,7 +482,7 @@ Tensor sub_tensor( self_arg, other_arg, std::optional(alpha), VK_KERNEL(sub)); } -Tensor& sub_tensor_( +static Tensor& sub_tensor_( Tensor& self, const Tensor& other_arg, const Scalar& alpha) { @@ -484,27 +490,27 @@ Tensor& sub_tensor_( self, other_arg, std::optional(alpha), VK_KERNEL(sub_inplace)); } -Tensor mul_scalar(const Tensor& self_arg, const Scalar& other) { +static Tensor mul_scalar(const Tensor& self_arg, const Scalar& other) { return binary_op_scalar( self_arg, other, std::optional(), VK_KERNEL(mul_scalar)); } -Tensor& mul_scalar_(Tensor& self, const Scalar& other) { +static Tensor& mul_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, other, std::optional(), VK_KERNEL(mul_scalar_inplace)); } -Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { +static Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) { return binary_op_tensor( self_arg, other_arg, std::optional(), VK_KERNEL(mul)); } -Tensor& mul_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& mul_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, std::optional(), VK_KERNEL(mul_inplace)); } -Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { +static Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { return binary_op_scalar( self_arg, 1.0 / other.to(), @@ -512,7 +518,7 @@ Tensor div_scalar(const Tensor& self_arg, const Scalar& other) { VK_KERNEL(mul_scalar)); } -Tensor& div_scalar_(Tensor& self, const Scalar& other) { +static Tensor& div_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, 1.0 / other.to(), @@ -520,31 +526,31 @@ Tensor& div_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(mul_scalar_inplace)); } -Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { +static Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) { return binary_op_tensor( self_arg, other_arg, std::optional(), VK_KERNEL(div)); } -Tensor& div_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& div_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, std::optional(), VK_KERNEL(div_inplace)); } -Tensor pow(const Tensor& self, const Tensor& other) { +static Tensor pow(const Tensor& self, const Tensor& other) { return binary_op_tensor(self, other, std::optional(), VK_KERNEL(pow)); } -Tensor& pow_(Tensor& self, const Tensor& other) { +static Tensor& pow_(Tensor& self, const Tensor& other) { return binary_op_tensor_( self, other, std::optional(), VK_KERNEL(pow_inplace)); } -Tensor pow_tensor_scalar(const Tensor& self, const Scalar& other) { +static Tensor pow_tensor_scalar(const Tensor& self, const Scalar& other) { return binary_op_scalar( self, other, std::optional(), VK_KERNEL(pow_tensor_scalar)); } -Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { +static Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { return binary_op_scalar_( self, other, @@ -552,12 +558,12 @@ Tensor& pow_tensor_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(pow_tensor_scalar_inplace)); } -Tensor pow_scalar_tensor(const Scalar& self, const Tensor& other) { +static Tensor pow_scalar_tensor(const Scalar& self, const Tensor& other) { return binary_op_scalar( other, self, std::optional(), VK_KERNEL(pow_scalar_tensor)); } -Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { +static Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { TORCH_CHECK( other.to() != 0.0f, "floor_divide_scalar: can't divide by zero"); return binary_op_scalar( @@ -567,7 +573,7 @@ Tensor floor_divide_scalar(const Tensor& self, const Scalar& other) { VK_KERNEL(floor_mul_scalar)); } -Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { +static Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { TORCH_CHECK( other.to() != 0.0f, "floor_divide_scalar_: can't divide by zero"); return binary_op_scalar_( @@ -577,12 +583,12 @@ Tensor& floor_divide_scalar_(Tensor& self, const Scalar& other) { VK_KERNEL(floor_mul_scalar_inplace)); } -Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) { +static Tensor floor_divide_tensor(const Tensor& self, const Tensor& other) { return binary_op_tensor( self, other, std::optional(), VK_KERNEL(floor_divide)); } -Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { +static Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { return binary_op_tensor_( self, other_arg, @@ -590,8 +596,6 @@ Tensor& floor_divide_tensor_(Tensor& self, const Tensor& other_arg) { VK_KERNEL(floor_divide_inplace)); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::add.Scalar"), TORCH_FN(add_scalar)); m.impl(TORCH_SELECTIVE_NAME("aten::add_.Scalar"), TORCH_FN(add_scalar_)); @@ -631,9 +635,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { TORCH_FN(floor_divide_tensor_)); } -#endif /* USE_VULKAN_API */ - } // namespace ops } // namespace vulkan } // namespace native } // namespace at +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp index f210c253800b19..3831cce8937367 100644 --- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp +++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp @@ -53,7 +53,7 @@ inline bool is_pointwise(const IntArrayRef weight_size) { return true; } -Conv2dMethod determine_method( +static Conv2dMethod determine_method( const IntArrayRef weight_size, const IntArrayRef stride, const IntArrayRef padding, @@ -359,7 +359,7 @@ struct Params final { api::utils::vec2 clamp; }; -void record_op( +static void record_op( api::Context* const context, api::ShaderInfo& compute_shader, vTensor& v_output, @@ -432,7 +432,7 @@ struct QParams final { api::utils::vec2 clamp; }; -void record_quantized_op( +static void record_quantized_op( api::Context* const context, api::ShaderInfo& compute_shader, vTensor& v_output, @@ -787,56 +787,11 @@ Tensor convolution( input, c10::make_intrusive(conv_context)); } -Tensor quantized_convolution( - const Tensor& input, - const Tensor& weight, - const std::optional& bias, - const IntArrayRef stride, - const IntArrayRef padding, - const IntArrayRef dilation, - const bool transposed, - const IntArrayRef output_padding, - const int64_t groups, - const double out_scale, - const int64_t out_zero_point) { - if (transposed) { - return run_tconv2d_context( - input, - c10::make_intrusive(Conv2dPackedContext( - weight, - bias, - stride, - padding, - dilation, - transposed, - false, - output_padding, - groups))); - } - - Conv2dPackedContext conv_context = Conv2dPackedContext( - weight, - bias, - stride, - padding, - dilation, - transposed, - true, - output_padding, - groups); - - return run_qconv2d_context( - input, - out_scale, - out_zero_point, - c10::make_intrusive(conv_context)); -} - } // namespace namespace conv1d { -vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { +static vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { Tensor weight = weight_arg; if (weight.is_cpu()) { @@ -862,7 +817,7 @@ vTensor pack_weights_using_width_packing(const Tensor& weight_arg) { * This is a full implementation. For algorithm details, refer to the shader * kernel code. */ -Tensor run_conv1d_context_impl( +static Tensor run_conv1d_context_impl( const Tensor& input_arg, const Tensor& weight_arg, const std::optional& bias_arg_opt, @@ -1150,7 +1105,7 @@ c10::intrusive_ptr create_qtconv2d_context( output_max)); } -Tensor run_conv2d_context_impl( +static Tensor run_conv2d_context_impl( const Tensor& input_arg, const c10::intrusive_ptr& conv_context, double scale, @@ -1291,30 +1246,6 @@ Tensor run_qconv2d_context( return run_conv2d_context_impl(input_arg, conv_context, scale, zero_point); } -Tensor quantized_conv2d( - const Tensor& input, - const Tensor& weight, - const std::optional& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - double out_scale, - int64_t out_zero_point) { - return quantized_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - false, - {{0, 0}}, - groups, - out_scale, - out_zero_point); -} - /* Backwards compatibility */ Conv2dOpContext::Conv2dOpContext(Conv2dPackedContext conv_context) : conv_context_{std::move(conv_context)} {} @@ -1444,7 +1375,7 @@ c10::intrusive_ptr create_conv1d_context( Conv1dPackedContext(weight, bias, stride, padding, dilation, groups)); } -Tensor convolution1d( +static Tensor convolution1d( const Tensor& input, const Tensor& weight, const std::optional& bias, diff --git a/aten/src/ATen/native/vulkan/ops/Copy.cpp b/aten/src/ATen/native/vulkan/ops/Copy.cpp index 60bc3a341ba0df..d8dcee9391fbe2 100644 --- a/aten/src/ATen/native/vulkan/ops/Copy.cpp +++ b/aten/src/ATen/native/vulkan/ops/Copy.cpp @@ -123,7 +123,7 @@ void transfer_vulkan_to_cpu(vTensor& v_src, Tensor& dst) { .to(convert_dtype(v_src.dtype())); } -void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) { +static void transfer_vulkan_to_vulkan(vTensor& src, vTensor& dst) { api::Context* const context = api::context(); api::PipelineBarrier pipeline_barrier{}; diff --git a/aten/src/ATen/native/vulkan/ops/Factory.cpp b/aten/src/ATen/native/vulkan/ops/Factory.cpp index 153a6448eaf4ab..d8cd21eb659f3e 100644 --- a/aten/src/ATen/native/vulkan/ops/Factory.cpp +++ b/aten/src/ATen/native/vulkan/ops/Factory.cpp @@ -28,7 +28,7 @@ Tensor _empty_affine_quantized( }); } -Tensor empty_memory_format( +static Tensor empty_memory_format( const IntArrayRef sizes, const std::optional dtype, const std::optional layout, @@ -46,7 +46,7 @@ Tensor empty_memory_format( }); } -Tensor empty_strided( +static Tensor empty_strided( const IntArrayRef sizes, const IntArrayRef /* strides */, const std::optional dtype, diff --git a/aten/src/ATen/native/vulkan/ops/Gru.cpp b/aten/src/ATen/native/vulkan/ops/Gru.cpp index a66c69b134cee1..e803e9b9686f92 100644 --- a/aten/src/ATen/native/vulkan/ops/Gru.cpp +++ b/aten/src/ATen/native/vulkan/ops/Gru.cpp @@ -135,7 +135,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { } // namespace -std::vector> pack_linear_op_contexts( +static std::vector> +pack_linear_op_contexts( const std::vector& params_cpu, int64_t num_layers) { TORCH_CHECK( diff --git a/aten/src/ATen/native/vulkan/ops/Layernorm.cpp b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp index 6b6a4b866c700c..f2d285e736f4a5 100644 --- a/aten/src/ATen/native/vulkan/ops/Layernorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Layernorm.cpp @@ -78,7 +78,7 @@ Tensor run_layernorm_context( return std::get<0>(native_layer_norm_output); } -Tensor layer_norm( +static Tensor layer_norm( const at::Tensor& input_arg, IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, diff --git a/aten/src/ATen/native/vulkan/ops/Lstm.cpp b/aten/src/ATen/native/vulkan/ops/Lstm.cpp index 7e8000370346aa..63f17f15eb2c14 100644 --- a/aten/src/ATen/native/vulkan/ops/Lstm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Lstm.cpp @@ -171,7 +171,7 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { } // namespace -std::vector> +static std::vector> pack_lstm_linear_op_contexts( const std::vector& params_cpu, int64_t num_layers) { diff --git a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp index 81f1a9c0197ab2..228a1ed5262a67 100644 --- a/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp +++ b/aten/src/ATen/native/vulkan/ops/QuantizedTensor.cpp @@ -1,3 +1,4 @@ +#ifdef USE_VULKAN_API #include #include #include @@ -161,13 +162,13 @@ Tensor dequantize_helper( return convert(v_output); } -double q_scale(const Tensor& self) { +static double q_scale(const Tensor& self) { TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_scale"); const vTensor& v_input = convert(self); return v_input.get_scale(); } -int64_t q_zero_point(const Tensor& self) { +static int64_t q_zero_point(const Tensor& self) { TORCH_CHECK(self.is_vulkan(), "Expecting a vulkan tensor for q_zero_point"); const vTensor& v_input = convert(self); return v_input.get_zero_point(); @@ -179,8 +180,6 @@ Tensor dequantize(const Tensor& self) { return dequantize_helper(self, q_scale, zero_point, kFloat); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl( TORCH_SELECTIVE_NAME("aten::quantize_per_tensor"), quantize_per_tensor); @@ -192,9 +191,8 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::dequantize.self"), dequantize); } -#endif /* USE_VULKAN_API */ - } // namespace ops } // namespace vulkan } // namespace native } // namespace at +#endif /* USE_VULKAN_API */ diff --git a/aten/src/ATen/native/vulkan/ops/Random.cpp b/aten/src/ATen/native/vulkan/ops/Random.cpp index 3103f7fe6f58d1..49199b48cb9709 100644 --- a/aten/src/ATen/native/vulkan/ops/Random.cpp +++ b/aten/src/ATen/native/vulkan/ops/Random.cpp @@ -12,7 +12,9 @@ namespace ops { using namespace api::utils; -Tensor& uniform_( +#ifdef USE_VULKAN_API + +static Tensor& uniform_( Tensor& self, const double from, const double to, @@ -57,7 +59,7 @@ Tensor& uniform_( return self; } -Tensor rand_like( +static Tensor rand_like( const at::Tensor& input_arg, const std::optional /* not implemented */, const std::optional /* not implemented */, @@ -71,7 +73,7 @@ Tensor rand_like( return input_arg.clone().detach().uniform_(0.0, 1.0); } -Tensor& normal_( +static Tensor& normal_( Tensor& self, const double mean, const double std, @@ -118,7 +120,7 @@ Tensor& normal_( return self; } -Tensor randn_like( +static Tensor randn_like( const at::Tensor& input_arg, const std::optional /* not implemented */, const std::optional /* not implemented */, @@ -130,8 +132,6 @@ Tensor randn_like( return input_arg.clone().detach().normal_(0.0, 1.0); } -#ifdef USE_VULKAN_API - TORCH_LIBRARY_IMPL(aten, Vulkan, m) { m.impl(TORCH_SELECTIVE_NAME("aten::uniform_"), TORCH_FN(uniform_)); m.impl(TORCH_SELECTIVE_NAME("aten::rand_like"), TORCH_FN(rand_like)); diff --git a/aten/src/ATen/native/vulkan/ops/Shape.cpp b/aten/src/ATen/native/vulkan/ops/Shape.cpp index 2a13979523f638..fd10a05a7c87f2 100644 --- a/aten/src/ATen/native/vulkan/ops/Shape.cpp +++ b/aten/src/ATen/native/vulkan/ops/Shape.cpp @@ -8,7 +8,7 @@ namespace native { namespace vulkan { namespace ops { -Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { +static Tensor view_internal(const Tensor& self_arg, const IntArrayRef shape) { api::Context* const context = api::context(); Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan(); @@ -52,7 +52,7 @@ inline Tensor view(const Tensor& self_arg, IntArrayRef shape) { return view_internal(self_arg, shape); } -Tensor _reshape_alias( +static Tensor _reshape_alias( const Tensor& self_arg, const IntArrayRef shape, const IntArrayRef strides) { diff --git a/aten/src/ATen/native/vulkan/ops/Upsample.cpp b/aten/src/ATen/native/vulkan/ops/Upsample.cpp index 7e3a2ead2d632f..fc426e2da73838 100644 --- a/aten/src/ATen/native/vulkan/ops/Upsample.cpp +++ b/aten/src/ATen/native/vulkan/ops/Upsample.cpp @@ -9,7 +9,7 @@ namespace vulkan { namespace ops { using namespace api::utils; -Tensor upsample_nearest2d( +static Tensor upsample_nearest2d( const Tensor& input_arg, const IntArrayRef output_sizes, const std::optional scales_h, @@ -94,7 +94,7 @@ Tensor upsample_nearest2d( return convert(v_output); } -Tensor upsample_bilinear2d( +static Tensor upsample_bilinear2d( const Tensor& input_arg, const IntArrayRef output_sizes, bool align_corners, diff --git a/aten/src/ATen/native/vulkan/ops/Utils.cpp b/aten/src/ATen/native/vulkan/ops/Utils.cpp index 1e6c18cfa43b1e..36d4221c666bf6 100644 --- a/aten/src/ATen/native/vulkan/ops/Utils.cpp +++ b/aten/src/ATen/native/vulkan/ops/Utils.cpp @@ -1,5 +1,6 @@ #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 63fbcb55e96d2b..125c1d8c491101 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -319,7 +319,7 @@ struct TORCH_API RecordFunction { if (!isActive()) { return; } - kwinputs_ = *kwargs; + kwinputs_ = std::unordered_map(*kwargs); before(std::move(fn), args, current_sequence_nr); } diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index f9a0557f8bdfff..a2c8da12c446bf 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -992,6 +992,9 @@ namespace { blend_init(a, b); test_blendv(expected_val, a, b, mask); } +// NOTE: In this test, blend is not required to implement SVE Vectorized::set. +// so, this test is disabled for SVE. +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(BitwiseFloatsAdditional2, Blend) { using vec = TypeParam; using VT = ValueType; @@ -1005,6 +1008,7 @@ namespace { constexpr int64_t power_sets = 1LL << (vec::size()); test_blend(expected_val, a, b); } +#endif template // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) void test_set(VT expected_val[vec::size()], VT a[vec::size()], VT b[vec::size()], int64_t count){ @@ -1606,6 +1610,7 @@ namespace { ASSERT_TRUE(vec_pinf.has_inf_nan()) << "Test failed for positive Infinity\n"; ASSERT_TRUE(vec_ninf.has_inf_nan()) << "Test failed for negative Infinity\n"; } +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecConvertTests, Convert) { using vec = TypeParam; using src_t = ValueType; @@ -1658,43 +1663,91 @@ namespace { TEST_CONVERT_TO(double); #undef TEST_CONVERT_TO } +#endif TYPED_TEST(VecMaskTests, MaskedLoad) { using vec = TypeParam; - using VT = ValueType; - constexpr auto N = vec::size(); - CACHE_ALIGN VT x[N]; - CACHE_ALIGN VT y[N]; - CACHE_ALIGN VT ref[N]; - auto seed = TestSeed(); - ValueGen generator(VT(-100), VT(100), seed); - for (const auto i : c10::irange(N)) { - x[i] = generator.get(); - } - auto vec_mask = generate_vec_mask(seed); - auto x_vec = vec_mask.template loadu(x); - x_vec.store(y); - for (const auto i : c10::irange(N)) { - if (vec_mask.is_masked(i)) { - ref[i] = x[i]; - } else { - ref[i] = 0; - } - } - for (const auto i : c10::irange(N)) { - ASSERT_EQ(y[i], ref[i]) - << "Failure Details:\nTest Seed to reproduce: " << seed; - } + using src_t = ValueType; + constexpr auto size = vec::size(); + + #define TEST_MASK_LOAD(dst_t, mask_t, mask_n) \ + do { \ + CACHE_ALIGN dst_t x[mask_n * size]; \ + CACHE_ALIGN dst_t y[mask_n * size]; \ + CACHE_ALIGN dst_t ref[mask_n * size]; \ + auto seed = TestSeed(); \ + ValueGen generator(dst_t(-100), dst_t(100), seed); \ + for (const auto i : c10::irange(mask_n * size)) { \ + x[i] = generator.get(); \ + } \ + auto vec_mask = generate_vec_mask(seed); \ + constexpr int dst_size = at::vec::Vectorized::size(); \ + constexpr int dst_n = mask_n * size / dst_size; \ + constexpr int rnd_n = (mask_n * size + dst_size - 1) / dst_size; \ + if constexpr(dst_n * dst_size >= mask_n * size) { \ + auto x_vec = vec_mask.template loadu(x); \ + x_vec.store(y); \ + for (const auto i : c10::irange(mask_n * size)) { \ + if (vec_mask.is_masked(i)) { \ + ref[i] = x[i]; \ + } else { \ + ref[i] = 0; \ + } \ + } \ + for (const auto i : c10::irange(mask_n * size)) { \ + ASSERT_EQ(y[i], ref[i]) \ + << "Failure Details:\nTest Seed to reproduce: " << seed; \ + } \ + } \ + } while (0) + + + #define TEST_MASK_LOAD_N(N) \ + TEST_MASK_LOAD(int8_t, src_t, N); \ + TEST_MASK_LOAD(uint8_t, src_t, N); \ + TEST_MASK_LOAD(int16_t, src_t, N); \ + TEST_MASK_LOAD(uint16_t, src_t, N); \ + TEST_MASK_LOAD(int32_t, src_t, N); \ + TEST_MASK_LOAD(uint32_t, src_t, N); \ + TEST_MASK_LOAD(int64_t, src_t, N); \ + TEST_MASK_LOAD(uint64_t, src_t, N); \ + TEST_MASK_LOAD(c10::BFloat16, src_t, N); \ + TEST_MASK_LOAD(c10::Half, src_t, N); \ + TEST_MASK_LOAD(float, src_t, N); \ + TEST_MASK_LOAD(double, src_t, N); + + TEST_MASK_LOAD_N(1) + TEST_MASK_LOAD_N(2) + TEST_MASK_LOAD_N(4) + + #undef TEST_MASK_LOAD + #undef TEST_MASK_LOAD_N } +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, MaskedCheck) { using VT = ValueType; - auto vec_mask = create_vec_mask(0); - ASSERT_TRUE(vec_mask.all_zero()) << "all_zero check failed"; - vec_mask = create_vec_mask(-1); - ASSERT_TRUE(vec_mask.all_masked()) << "all_masked check failed"; - vec_mask = create_vec_mask(2); - ASSERT_TRUE(vec_mask.is_masked(1)) << "is_masked(1) check failed"; - ASSERT_TRUE(!vec_mask.is_masked(0)) << "!is_masked(0) check failed"; + using vec = TypeParam; + constexpr auto size = vec::size(); + #define TEST_MASK_CHECK_N(N) \ + do { \ + auto vec_mask = create_vec_mask(0); \ + ASSERT_TRUE(vec_mask.all_zero()) << "all_zero check failed"; \ + vec_mask = create_vec_mask(-1); \ + ASSERT_TRUE(vec_mask.all_masked()) << "all_masked check failed"; \ + vec_mask = create_vec_mask(2); \ + for (int i = 0; i < N; i ++) { \ + ASSERT_TRUE(vec_mask.is_masked(1 + i * size)) << "is_masked(1) check failed"; \ + ASSERT_TRUE(!vec_mask.is_masked(0 + i * size)) << "!is_masked(0) check failed"; \ + } \ + } while (0) + + TEST_MASK_CHECK_N(1); + TEST_MASK_CHECK_N(2); + TEST_MASK_CHECK_N(4); + + #undef TEST_MASK_CHECK_N } +#endif +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, ToFrom) { using vec = TypeParam; using VT = ValueType; @@ -1720,41 +1773,53 @@ namespace { << "Failure Details:\nTest Seed to reproduce: " << seed; } } +#endif +#if !defined(CPU_CAPABILITY_SVE) TYPED_TEST(VecMaskTests, Cast) { using vec = TypeParam; using src_t = ValueType; - constexpr auto N = vec::size(); - #define TEST_MASK_CAST(dst_t) \ + constexpr auto size = vec::size(); + + #define TEST_MASK_CAST(dst_t, mask_t, mask_n) \ do { \ - CACHE_ALIGN src_t x[N]; \ - CACHE_ALIGN dst_t y[N]; \ + CACHE_ALIGN mask_t x[mask_n * size]; \ + CACHE_ALIGN dst_t y[mask_n * size]; \ auto seed = TestSeed(); \ - auto vec_mask = generate_vec_mask(seed); \ + auto vec_mask = generate_vec_mask(seed); \ constexpr int num_dst_elements = \ - std::min(N, at::vec::Vectorized::size()); \ - constexpr int dst_n = N / num_dst_elements; \ + std::min(size, at::vec::Vectorized::size()); \ + constexpr int dst_n = mask_n * size / num_dst_elements; \ auto vec_mask_new = vec_mask.template cast(); \ - vec_mask.template to().store(x); \ - vec_mask_new.template to().store(y, N); \ - for (const auto i : c10::irange(N)) { \ + vec_mask.template to().store(x); \ + vec_mask_new.template to().store(y); \ + for (const auto i : c10::irange(mask_n * size)) { \ ASSERT_EQ(y[i], x[i]) \ << "Failure Details:\nTest Seed to reproduce: " << seed; \ } \ } while (0) - TEST_MASK_CAST(int8_t); - TEST_MASK_CAST(uint8_t); - TEST_MASK_CAST(int16_t); - TEST_MASK_CAST(uint16_t); - TEST_MASK_CAST(int32_t); - TEST_MASK_CAST(uint32_t); - TEST_MASK_CAST(int64_t); - TEST_MASK_CAST(uint64_t); - TEST_MASK_CAST(c10::BFloat16); - TEST_MASK_CAST(c10::Half); - TEST_MASK_CAST(float); - TEST_MASK_CAST(double); + + #define TEST_MASK_CAST_N(N) \ + TEST_MASK_CAST(int8_t, src_t, N); \ + TEST_MASK_CAST(uint8_t, src_t, N); \ + TEST_MASK_CAST(int16_t, src_t, N); \ + TEST_MASK_CAST(uint16_t, src_t, N); \ + TEST_MASK_CAST(int32_t, src_t, N); \ + TEST_MASK_CAST(uint32_t, src_t, N); \ + TEST_MASK_CAST(int64_t, src_t, N); \ + TEST_MASK_CAST(uint64_t, src_t, N); \ + TEST_MASK_CAST(c10::BFloat16, src_t, N); \ + TEST_MASK_CAST(c10::Half, src_t, N); \ + TEST_MASK_CAST(float, src_t, N); \ + TEST_MASK_CAST(double, src_t, N); + + TEST_MASK_CAST_N(1) + TEST_MASK_CAST_N(2) + TEST_MASK_CAST_N(4) + #undef TEST_MASK_CAST + #undef TEST_MASK_CAST_N } +#endif #else #error GTEST does not have TYPED_TEST #endif diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 91788fcd56039d..9215e9ff393f38 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -53,6 +53,9 @@ CACHE_ALIGN #define defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #define CHECK_WITH_FMA 1 +#elif defined(CPU_CAPABILITY_SVE) +#define CHECK_DEQUANT_WITH_LOW_PRECISION 1 +#define CHECK_WITH_FMA 1 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #undef CHECK_WITH_FMA @@ -1434,22 +1437,24 @@ double getDefaultTolerance() { return 1.e-9; } -template -at::vec::VecMask create_vec_mask(uint64_t bitmask) { - constexpr auto N = at::vec::Vectorized::size(); - std::array mask; - for (int i = 0; i < N; i++) { - mask[i] = (bitmask >> i) & 1; +template +at::vec::VecMask create_vec_mask(uint64_t bitmask) { + constexpr auto size = at::vec::Vectorized::size(); + std::array mask; + for (int n = 0; n < N; n++) { + for (int i = 0; i < size; i++) { + mask[n * size + i] = (bitmask >> i) & 1; + } } - return at::vec::VecMask::from(mask.data()); + return at::vec::VecMask::from(mask.data()); } -template -at::vec::VecMask generate_vec_mask(int seed) { - constexpr auto N = at::vec::Vectorized::size(); - ValueGen generator(0, (1ULL << N) - 1, seed); +template +at::vec::VecMask generate_vec_mask(int seed) { + constexpr auto size = at::vec::Vectorized::size(); + ValueGen generator(0, (1ULL << size) - 1, seed); auto bitmask = generator.get(); - return create_vec_mask(bitmask); + return create_vec_mask(bitmask); } template diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index 904ba00fbdd759..c96684bc79462c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv index f1b119c0284058..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 @@ -278,8 +278,8 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 -yolov3,fail_accuracy,8 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv index 784d3788e33554..1cafcbe55675d3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_huggingface_inference.csv @@ -10,10 +10,6 @@ AlbertForQuestionAnswering,pass,0 -AllenaiLongformerBase,fail_to_run,0 - - - BartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv index c7e86a6d317eb5..dc36107e0d0292 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv @@ -114,7 +114,7 @@ lcnet_050,pass,0 -levit_128,pass,0 +levit_128,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index 6cbf71ffeded09..fe3c67bba120b3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -6,7 +6,7 @@ torchrec_dlrm,eager_fail_to_run,0 -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -282,7 +282,7 @@ sam,fail_to_run,0 -sam_fast,fail_to_run,0 +sam_fast,timeout,0 @@ -346,4 +346,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv similarity index 85% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv index fae39359c6148f..71345480d423d8 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_amp_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -98,7 +90,7 @@ detectron2_maskrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 @@ -138,7 +130,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,0 @@ -166,10 +158,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +170,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -214,18 +194,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,14 +254,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - shufflenet_v2_x1_0,pass,0 @@ -302,10 +266,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -346,7 +306,7 @@ torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +318,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv similarity index 97% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_huggingface_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv index 784d3788e33554..1cafcbe55675d3 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_huggingface_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_huggingface_inference.csv @@ -10,10 +10,6 @@ AlbertForQuestionAnswering,pass,0 -AllenaiLongformerBase,fail_to_run,0 - - - BartForCausalLM,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_timm_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv similarity index 85% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv index 5e5e0a0d507ff5..282c18feea4f64 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_aot_inductor_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -98,7 +90,7 @@ detectron2_maskrcnn_r_50_fpn,fail_to_run,0 -dlrm,fail_to_run,0 +dlrm,pass,0 @@ -138,7 +130,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,0 @@ -166,10 +158,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +170,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -214,18 +194,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,14 +254,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - shufflenet_v2_x1_0,pass,0 @@ -302,10 +266,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -346,7 +306,7 @@ torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +318,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_huggingface_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv similarity index 95% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index c900d1768921b6..dafbd90e9aa793 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_amp_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -10,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 - - - LearningToPaint,pass,0 @@ -142,7 +138,7 @@ hf_Bert_large,pass,0 -hf_BigBird,pass,61 +hf_BigBird,pass,19 @@ -194,11 +190,11 @@ maml,pass_due_to_skip,0 -maml_omniglot,pass,0 +mnasnet1_0,pass,0 -mnasnet1_0,pass,0 +maml_omniglot,pass,0 @@ -214,10 +210,6 @@ mobilenet_v3_large,pass,0 -moco,model_fail_to_load,0 - - - moondream,pass,0 @@ -346,7 +338,7 @@ vgg16,pass,0 -vision_maskrcnn,fail_accuracy,28 +vision_maskrcnn,fail_accuracy,30 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_huggingface_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv similarity index 100% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_timm_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_timm_inference.csv diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv similarity index 94% rename from benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index 10f3904d678b72..a897806e5188b7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -10,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,model_fail_to_load,0 - - - LearningToPaint,pass,0 @@ -82,7 +78,7 @@ detectron2_fcos_r_50_fpn,pass,23 -detectron2_maskrcnn_r_101_c4,fail_accuracy,57 +detectron2_maskrcnn_r_101_c4,pass,57 @@ -142,6 +138,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,19 + + + hf_DistilBert,pass,0 @@ -154,10 +154,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -202,10 +210,6 @@ mobilenet_v3_large,pass,0 -moco,model_fail_to_load,0 - - - moondream,pass,0 @@ -334,7 +338,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index dc589b9c49c10a..c8980699f9615d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -138,6 +138,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,13 + + + hf_DistilBert,pass,0 @@ -150,10 +154,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -330,7 +342,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 7a14abcf8e2910..9075a4adfd3a1a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -366,7 +366,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv index 38ff07764e08d2..5d2fe7d197768e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 @@ -274,8 +274,8 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 -yolov3,fail_accuracy,8 +yolov3,pass,8 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv similarity index 77% rename from benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv index 4abe5ae064a93a..9fe2b93f08e814 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_amp_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_amp_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -82,23 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -detectron2_maskrcnn_r_101_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_101_fpn,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_fpn,fail_to_run,0 - - - -dlrm,fail_to_run,0 +dlrm,pass,0 @@ -138,7 +114,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,0 @@ -166,10 +142,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +154,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -198,11 +162,11 @@ maml_omniglot,pass,0 -mnasnet1_0,pass,0 +mobilenet_v2,pass,0 -mobilenet_v2,pass,0 +mnasnet1_0,pass,0 @@ -214,18 +178,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,14 +238,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - shufflenet_v2_x1_0,pass,0 @@ -302,10 +250,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -342,11 +286,11 @@ timm_vovnet,pass,0 -torch_multimodal_clip,fail_to_run,0 +torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +302,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv similarity index 77% rename from benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv rename to benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv index 4360e2858cce24..98fe4427622601 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_torchbench_freezing_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_aot_inductor_freezing_torchbench_inference.csv @@ -2,11 +2,7 @@ name,accuracy,graph_breaks -torchrec_dlrm,eager_fail_to_run,0 - - - -BERT_pytorch,fail_to_run,0 +BERT_pytorch,pass,0 @@ -14,10 +10,6 @@ Background_Matting,pass_due_to_skip,0 -DALLE2_pytorch,fail_to_run,0 - - - LearningToPaint,pass,0 @@ -82,23 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_to_run,0 -detectron2_maskrcnn_r_101_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_101_fpn,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_c4,fail_to_run,0 - - - -detectron2_maskrcnn_r_50_fpn,fail_to_run,0 - - - -dlrm,fail_to_run,0 +dlrm,pass,0 @@ -138,7 +114,7 @@ hf_Bert_large,pass,0 -hf_BigBird,fail_to_run,0 +hf_BigBird,pass,0 @@ -166,10 +142,6 @@ hf_T5_large,pass_due_to_skip,0 -hf_Whisper,pass,0 - - - hf_distil_whisper,pass,0 @@ -182,14 +154,6 @@ llama,fail_to_run,0 -llama_v2_7b_16h,model_fail_to_load,0 - - - -llava,model_fail_to_load,0 - - - maml,pass_due_to_skip,0 @@ -214,18 +178,10 @@ mobilenet_v3_large,pass,0 -moco,fail_to_run,0 - - - moondream,pass,0 -nanogpt,pass,0 - - - nvidia_deeprecommender,pass,0 @@ -282,14 +238,6 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 - - - -sam_fast,fail_to_run,0 - - - shufflenet_v2_x1_0,pass,0 @@ -302,10 +250,6 @@ squeezenet1_1,pass,0 -stable_diffusion_text_encoder,pass,0 - - - stable_diffusion_unet,pass_due_to_skip,0 @@ -342,11 +286,11 @@ timm_vovnet,pass,0 -torch_multimodal_clip,fail_to_run,0 +torch_multimodal_clip,pass,0 -tts_angular,fail_to_run,0 +tts_angular,pass,0 @@ -358,4 +302,4 @@ vision_maskrcnn,fail_to_run,0 -yolov3,pass,0 +yolov3,fail_to_run,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index cba443a815dd97..9b6ec5b6cddecf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -122,6 +122,10 @@ hf_Bert_large,pass,0 +hf_BigBird,pass,13 + + + hf_DistilBert,pass,0 @@ -134,10 +138,18 @@ hf_GPT2_large,pass_due_to_skip,0 +hf_Longformer,pass,4 + + + hf_Reformer,pass,5 +hf_T5,pass,0 + + + hf_T5_base,pass,0 @@ -314,7 +326,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,27 +vision_maskrcnn,pass,29 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 7a14abcf8e2910..9075a4adfd3a1a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -366,7 +366,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index 66b96ee134c9b7..ab99edec8b4ecf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 @@ -274,7 +274,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index 904ba00fbdd759..c96684bc79462c 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv index 05d52bedb3e046..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index aea126aa270cd8..f21050a3d3d959 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -370,7 +370,7 @@ vgg16,pass,0 -vision_maskrcnn,pass,16 +vision_maskrcnn,pass,18 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index 05d52bedb3e046..914849fa010cc1 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -110,7 +110,7 @@ hf_GPT2_large,pass_due_to_skip,0 -hf_Reformer,pass,25 +hf_Reformer,pass,23 @@ -278,7 +278,7 @@ vgg16,pass,6 -vision_maskrcnn,pass,33 +vision_maskrcnn,pass,35 diff --git a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py index 29d204c02e218d..c6bd92ae8ed015 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/update_expected.py +++ b/benchmarks/dynamo/ci_expected_accuracy/update_expected.py @@ -152,8 +152,16 @@ def apply_lints(filename): [ "aot_eager", "aot_inductor", + "cpu_aot_inductor", + "cpu_aot_inductor_amp_freezing", + "cpu_aot_inductor_freezing", "cpu_inductor", + "cpu_inductor_amp_freezing", + "cpu_inductor_freezing", "dynamic_aot_eager", + "dynamic_cpu_aot_inductor", + "dynamic_cpu_aot_inductor_amp_freezing", + "dynamic_cpu_aot_inductor_freezing", "dynamic_cpu_inductor", "dynamic_inductor", "dynamo_eager", diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 6457e1b0814d1a..104e59bc193a41 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -39,6 +39,7 @@ from unittest.mock import MagicMock import numpy as np +import numpy.typing as npt import pandas as pd import psutil import yaml @@ -158,6 +159,7 @@ class CI(NamedTuple): "detectron2_fasterrcnn_r_50_fpn", "hf_T5_generate", "Reformer", + "llama", }.union(INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY) # These models currently fail accuracy with eager Adam optimizer @@ -1338,6 +1340,16 @@ def try_script(model, example_inputs): return None +def _produce_dynamic_shapes_for_export(path, x): + # mark_dynamic() is ignored for export. + # use this to produce dynamic_shapes spec instead. + from torch.export.dynamic_shapes import Dim + + if not isinstance(x, torch.Tensor): + return None + return {i: Dim.AUTO for i in getattr(x, "_dynamo_dynamic_indices", {})} + + class AOTInductorModelCache: cache = {} @@ -1345,6 +1357,7 @@ class AOTInductorModelCache: def load(cls, model, example_inputs, device): import torch._inductor import torch.export._trace + from torch.export.dynamic_shapes import _tree_map_with_path key = weakref.ref(model) if key not in cls.cache: @@ -1364,14 +1377,19 @@ def load(cls, model, example_inputs, device): else: _register_dataclass_output_as_pytree(example_outputs) - # TODO(angelayi): change this to predispatch - # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing - # to predispatch to avoid performance regressions - gm = torch.export._trace._export_to_torch_ir( + combined_args = tuple(example_args) + tuple(example_kwargs.values()) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + + gm = torch.export._trace._export( model, example_args, example_kwargs, - ) + dynamic_shapes=dynamic_shapes, + pre_dispatch=True, + strict=False, + ).module() with torch.no_grad(): so_path = torch._inductor.aot_compile( gm, example_args, example_kwargs @@ -1383,11 +1401,20 @@ def load(cls, model, example_inputs, device): def export(model, example_inputs): + from torch.export.dynamic_shapes import _tree_map_with_path + example_args, example_kwargs = _normalize_bench_inputs(example_inputs) example_outputs = model(*example_args, **example_kwargs) _register_dataclass_output_as_pytree(example_outputs) - ep = torch.export.export(model, example_args, example_kwargs) + combined_args = tuple(example_args) + tuple(example_kwargs.values()) + dynamic_shapes = _tree_map_with_path( + _produce_dynamic_shapes_for_export, combined_args + ) + + ep = torch.export.export( + model, example_args, example_kwargs, dynamic_shapes=dynamic_shapes + ) def opt_export(_, example_inputs): example_args, example_kwargs = _normalize_bench_inputs(example_inputs) @@ -1539,14 +1566,14 @@ def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ... - def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]: + def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, npt.NDArray]: pt_inputs = self.format_pt_inputs(pt_inputs) return { ort_input.name: pt_input.cpu().numpy() for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs) } - def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any: + def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[npt.NDArray]) -> Any: pt_outputs = [ torch.from_numpy(onnx_output).to(current_device) for onnx_output in onnx_outputs @@ -2371,7 +2398,11 @@ def skip_models_for_cpu(self): return set() @property - def skip_models_for_freezing(self): + def skip_models_for_freezing_cpu(self): + return set() + + @property + def skip_models_for_freezing_cuda(self): return set() @property @@ -4276,7 +4307,6 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.slow_models) if args.devices == ["cpu"]: - runner.skip_models.update(runner.very_slow_models) runner.skip_models.update(runner.skip_models_for_cpu) elif args.devices == ["cuda"]: runner.skip_models.update(runner.skip_models_for_cuda) @@ -4285,7 +4315,10 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.skip_multiprocess_models) if args.freezing: - runner.skip_models.update(runner.skip_models_for_freezing) + if args.devices == ["cpu"]: + runner.skip_models.update(runner.skip_models_for_freezing_cpu) + elif args.devices == ["cuda"]: + runner.skip_models.update(runner.skip_models_for_freezing_cuda) if args.no_skip: runner.skip_models.clear() diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 40268c0773ae74..06bf4f0ee7610a 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -505,7 +505,7 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): return 4e-3, cosine if ( current_device == "cpu" - and name in self._config["tolerance"]["higher_inference"] + and name in self._config["tolerance"]["higher_inference_cpu"] ): return 4e-3, cosine return 1e-3, cosine diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index 9650defcbfac66..2ddc242537d6e7 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -11,9 +11,7 @@ skip: - GPTJForQuestionAnswering device: - cpu: - # OOMs - - OPTForCausalLM + cpu: [] control_flow: - AllenaiLongformerBase @@ -71,6 +69,7 @@ batch_size: TrOCRForCausalLM: 2 XGLMForCausalLM: 4 XLNetLMHeadModel: 2 + YituTechConvBert: 2 tolerance: diff --git a/benchmarks/dynamo/microbenchmarks/operatorbench.py b/benchmarks/dynamo/microbenchmarks/operatorbench.py index 2ebca7e340336c..d61ec36870563f 100644 --- a/benchmarks/dynamo/microbenchmarks/operatorbench.py +++ b/benchmarks/dynamo/microbenchmarks/operatorbench.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +from contextlib import nullcontext + import click import numpy as np from operator_inp_utils import OperatorInputsLoader @@ -16,11 +18,13 @@ aten = torch.ops.aten +profile_enabled = False def compute_speedups( operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda" ): + global profile_enabled expected = models[0](*example_inputs) if accuracy_checking: for model in models[1:]: @@ -35,20 +39,32 @@ def compute_speedups( timings = np.zeros((repeats, len(models)), np.float64) for rep in range(repeats): - # interleave the runs to handle frequency scaling and load changes - for m, model in enumerate(models): - if device == "cuda": - model(*example_inputs) - - # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time - # along with cuda synchronization - timings[rep, m] = benchmarker.benchmark_gpu( - lambda: model(*example_inputs) + record_rep_context = ( + torch.profiler.record_function(f"rep_{rep}") + if profile_enabled + else nullcontext() + ) + with record_rep_context: + # interleave the runs to handle frequency scaling and load changes + for m, model in enumerate(models): + record_model_context = ( + torch.profiler.record_function(f"model_{m}") + if profile_enabled + else nullcontext() ) - else: - from torch._inductor.utils import timed - - timings[rep, m] = timed(model, example_inputs) + with record_model_context: + if device == "cuda": + model(*example_inputs) + + # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time + # along with cuda synchronization + timings[rep, m] = benchmarker.benchmark_gpu( + lambda: model(*example_inputs) + ) + else: + from torch._inductor.utils import timed + + timings[rep, m] = timed(model, example_inputs) return np.median(timings, axis=0) @@ -171,6 +187,7 @@ def skip_operator(operator): @click.option( "--channels-last", help="force inputs to channels last", is_flag=True, default=False ) +@click.option("--profile", help="profile the benchmark", is_flag=True, default=False) def benchmark( suite, op, @@ -183,7 +200,9 @@ def benchmark( inp_file, start_idx, channels_last, + profile, ): + global profile_enabled if inp_file is not None: loader = OperatorInputsLoader(inp_file) else: @@ -209,6 +228,8 @@ def benchmark( ops = [eval(op)] max_samples = max_samples + start_idx + profile_enabled = profile + for operator in ops: if skip_operator(operator): continue @@ -216,10 +237,31 @@ def benchmark( print(f"Running {operator}") inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device) timings = [] - - for i in range(min(max_samples, 1000000)): + inputs_list = [] + for _ in range(min(max_samples, 1000000)): try: inps = next(inp_gen) + inputs_list.append(inps) + except StopIteration: + break + + profiler_context = ( + torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + profile_memory=False, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + f"./log/operator_{operator}", use_gzip=True + ), + ) + if profile_enabled + else nullcontext() + ) + with profiler_context as prof: + for i, inps in enumerate(inputs_list): if inps is None: break if i < start_idx: @@ -230,28 +272,32 @@ def benchmark( args, kwargs = tree_map_only( torch.Tensor, to_channels_last, (args, kwargs) ) - - except StopIteration: - break - try: - # aten, nvfuser, inductor - timings.append( - microbenchmark( - operator, - args, - kwargs, - dtype, - accuracy_checking, - repeats, - measure_nvfuser, - device, + try: + iter_context = ( + torch.profiler.record_function(f"iter_{i}") + if profile_enabled + else nullcontext() ) - ) - except Exception as e: - print(f"error {operator}") - print(e) - # comment out this line to avoid blocking other tests - # raise e + with iter_context: + # aten, nvfuser, inductor + timings.append( + microbenchmark( + operator, + args, + kwargs, + dtype, + accuracy_checking, + repeats, + measure_nvfuser, + device, + ) + ) + + except Exception as e: + print(f"error {operator}") + print(e) + # comment out this line to avoid blocking other tests + # raise e if not timings: continue diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index e9da5b3a45c519..0c0d0e7ecfa7b4 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -1,9 +1,12 @@ import csv +import gc from abc import ABC, abstractmethod from fbscribelogger import make_scribe_logger import torch._C._instruction_counter as i_counter +import torch._dynamo.config as config +from torch._dynamo.utils import CompileTimeInstructionCounter scribe_log_torch_benchmark_compile_time = make_scribe_logger( @@ -51,10 +54,26 @@ class BenchmarkBase(ABC): - _instruction_count = False + # measure total number of instruction spent in _work. + _enable_instruction_count = False + + # measure total number of instruction spent in convert_frame.compile_inner + # TODO is there other parts we need to add ? + _enable_compile_time_instruction_count = False + + # number of iterations used to run when collecting instruction_count or compile_time_instruction_count. + _num_iterations = 5 + + def with_iterations(self, value): + self._num_iterations = value + return self def enable_instruction_count(self): - self._instruction_count = True + self._enable_instruction_count = True + return self + + def enable_compile_time_instruction_count(self): + self._enable_compile_time_instruction_count = True return self def name(self): @@ -64,31 +83,56 @@ def description(self): return "" @abstractmethod - def prepare(self): + def _prepare(self): pass @abstractmethod - def work(self): + def _work(self): pass - def prepare_once(self): # noqa: B027 + def _prepare_once(self): # noqa: B027 pass - def count_instructions(self): + def _count_instructions(self): print(f"collecting instruction count for {self.name()}") - self.prepare_once() - results = [] - for i in range(10): - self.prepare() + for i in range(self._num_iterations): + self._prepare() id = i_counter.start() - self.work() + self._work() count = i_counter.end(id) print(f"instruction count for iteration {i} is {count}") - if i != 0: - results.append(count) + results.append(count) return min(results) + def _count_compile_time_instructions(self): + gc.disable() + + try: + print(f"collecting compile time instruction count for {self.name()}") + config.record_compile_time_instruction_count = True + + results = [] + for i in range(self._num_iterations): + self._prepare() + gc.collect() + # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner + # hence this will only count instruction count spent in compile_inner. + CompileTimeInstructionCounter.clear() + self._work() + count = CompileTimeInstructionCounter.value() + if count == 0: + raise RuntimeError( + "compile time instruction count is 0, please check your benchmarks" + ) + print(f"compile time instruction count for iteration {i} is {count}") + results.append(count) + + config.record_compile_time_instruction_count = False + return min(results) + finally: + gc.enable() + def append_results(self, path): with open(path, "a", newline="") as csvfile: # Create a writer object @@ -102,12 +146,36 @@ def print(self): print(f"{entry[0]},{entry[1]},{entry[2]}") def collect_all(self): + self._prepare_once() self.results = [] - if self._instruction_count: - r = self.count_instructions() + if ( + self._enable_instruction_count + and self._enable_compile_time_instruction_count + ): + raise RuntimeError( + "not supported until we update the logger, both logs to the same field now" + ) + + if self._enable_instruction_count: + r = self._count_instructions() self.results.append((self.name(), "instruction_count", r)) scribe_log_torch_benchmark_compile_time( name=self.name(), instruction_count=r, ) + if self._enable_compile_time_instruction_count: + r = self._count_compile_time_instructions() + + self.results.append( + ( + self.name(), + "compile_time_instruction_count", + r, + ) + ) + # TODO add a new field compile_time_instruction_count to the logger. + scribe_log_torch_benchmark_compile_time( + name=self.name(), + instruction_count=r, + ) return self diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh index 66ea3f0e0d8fd2..a5cf04173358c5 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh @@ -18,8 +18,17 @@ output_file=$1 # Set the directory of Python programs python_programs_dir=$2 # Loop through all files in the directory of Python programs + +start=`date +%s` + for file in $python_programs_dir/*.py do # Execute the Python program and append the output to the output file python $file $output_file done +end=`date +%s` + +runtime=$((end-start)) +echo "total time to run benchmarks is:" +echo $runtime +python benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py $runtime diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py new file mode 100644 index 00000000000000..f28d59f154ea28 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/add_loop.py @@ -0,0 +1,67 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch +from torch._inductor.utils import fresh_inductor_cache + + +class Benchmark(BenchmarkBase): + def __init__(self, backend, dynamic=False, is_gpu=False): + self._backend = backend + self._dynamic = dynamic + self._device = "cuda" if is_gpu else "cpu" + + def name(self): + prefix = f"add_loop_{self._backend}" + if self._dynamic: + prefix += "_dynamic" + if self._device == "cuda": + prefix += "_gpu" + return prefix + + def description(self): + return "a loop over 100 add node" + + def _prepare_once(self): + self.a = torch.ones(1000, device=self._device) + self.b = torch.torch.ones(1000, device=self._device) + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + @torch.compile(backend=self._backend, fullgraph=True, dynamic=self._dynamic) + def f(a, b): + result = a.clone() + for i in range(1000): + if i % 3 == 0: + result = result + b + elif i % 3 == 1: + result = result + 8 * b + else: + result = result.sin() + return result + + with fresh_inductor_cache(): + f(self.a, self.b) + + +def main(): + result_path = sys.argv[1] + all = [ + Benchmark("eager"), + Benchmark("eager", dynamic=True), + Benchmark("inductor"), + Benchmark("inductor", is_gpu=True), + Benchmark("inductor", is_gpu=True, dynamic=True), + ] + + for benchmark in all: + benchmark.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py new file mode 100644 index 00000000000000..d3311b7458a97c --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -0,0 +1,77 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch +import torch.nn as nn +from torch._inductor.utils import fresh_inductor_cache + + +class ListOfLinears(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + + +class Benchmark(BenchmarkBase): + def __init__( + self, ModuleClass, backend, is_gpu=False, dynamic=False, force_shape_pad=False + ): + self.ModuleClass = ModuleClass + self.backend = backend + self._name = ModuleClass.__name__ + self._is_gpu = is_gpu + self._dynamic = dynamic + self._force_shape_pad = force_shape_pad + + def name(self): + prefix = f"basic_modules_{self._name}_{self.backend}" + if self._dynamic: + prefix += "_dynamic" + if self._is_gpu: + prefix += "_gpu" + if self._force_shape_pad: + prefix += "_force_shape_pad" + return prefix + + def _prepare_once(self): + self.m = self.ModuleClass() + torch.set_float32_matmul_precision("high") + self.input = torch.ones(10, device="cuda" if self._is_gpu else "cpu") + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + with fresh_inductor_cache(), torch._inductor.config.patch( + force_shape_pad=self._force_shape_pad + ): + opt_m = torch.compile(backend=self.backend, dynamic=self._dynamic)( + self.m.cuda() if self._is_gpu else self.m + ) + opt_m(self.input) + + +def main(): + result_path = sys.argv[1] + benchmarks = [ + Benchmark(ListOfLinears, "eager"), + Benchmark(ListOfLinears, "inductor"), + Benchmark(ListOfLinears, "inductor", is_gpu=True), + Benchmark(ListOfLinears, "inductor", is_gpu=True), + Benchmark(ListOfLinears, "inductor", is_gpu=True, force_shape_pad=True), + ] + for b in benchmarks: + b.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py similarity index 78% rename from benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py rename to benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py index 227c0bd911ada3..3bd22d18bad00e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv.py @@ -14,7 +14,7 @@ def name(self): def description(self): return "information at https://github.com/pytorch/pytorch/issues/134133" - def prepare_once(self): + def _prepare_once(self): class M(torch.nn.Module): def forward(self, x): total = sum(t.item() for t in x) @@ -23,16 +23,18 @@ def forward(self, x): self.m = M() self.input = [torch.tensor(i + 2) for i in range(self.N)] - def prepare(self): + def _prepare(self): torch._dynamo.reset() - def work(self): + def _work(self): torch.export.export(self.m, (self.input,)) def main(): result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) + Benchmark().enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py index 92a83c609a1f2d..7957836b6a9d1d 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -1,4 +1,3 @@ -import random import sys from benchmark_base import BenchmarkBase @@ -15,17 +14,18 @@ def name(self): def description(self): return "information at https://github.com/pytorch/pytorch/pull/129893" - def prepare_once(self): + def _prepare_once(self): torch._dynamo.config.capture_scalar_outputs = True - random.seed(42) + torch.manual_seed(0) + self.splits = torch.randint(10, (self.N,)) sz = self.splits.sum().item() self.input = torch.randn(sz) - def prepare(self): + def _prepare(self): torch._dynamo.reset() - def work(self): + def _work(self): @torch.compile(fullgraph=True) def f(a, b): xs = b.tolist() @@ -39,7 +39,9 @@ def f(a, b): def main(): result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) + Benchmark().enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": diff --git a/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py b/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py new file mode 100644 index 00000000000000..d011758c23ffef --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/log_benchmarking_time.py @@ -0,0 +1,17 @@ +import json +import sys + +import torch._logging.scribe as scribe + + +def main(): + duration = int(sys.argv[1]) + scribe.open_source_signpost( + subsystem="pr_time_benchmarks", + name="duration", + parameters=json.dumps(duration), + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 747a8e06e635ad..2b13a4b5c5308a 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -387,11 +387,6 @@ def get_skip_tests(suite, device, is_training: bool): skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cpu) elif device == "cuda": skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cuda) - else: - if hasattr(module, "SKIP"): - skip_tests.update(module.SKIP) - if is_training and hasattr(module, "SKIP_TRAIN"): - skip_tests.update(module.SKIP_TRAIN) skip_tests = (f"-x {name}" for name in skip_tests) skip_str = " ".join(skip_tests) @@ -438,7 +433,7 @@ def generate_commands(args, dtypes, suites, devices, compilers, output_dir): if args.enable_cpu_launcher: launcher_cmd = f"python -m torch.backends.xeon.run_cpu {args.cpu_launcher_args}" cmd = f"{launcher_cmd} benchmarks/dynamo/{suite}.py --{testing} --{dtype} -d{device} --output={output_filename}" - cmd = f"{cmd} {base_cmd} {args.extra_args} --no-skip --dashboard" + cmd = f"{cmd} {base_cmd} {args.extra_args} --dashboard" skip_tests_str = get_skip_tests(suite, device, args.training) cmd = f"{cmd} {skip_tests_str}" diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index dcbde79139b7af..b27951b16494d6 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -56,6 +56,9 @@ def setup_torchbench_cwd(): "../../torchbenchmark", "../../torchbench", "../../benchmark", + "../../../torchbenchmark", + "../../../torchbench", + "../../../benchmark", ): if exists(torchbench_dir): break @@ -135,8 +138,12 @@ def skip_models_for_cuda(self): return self._skip["device"]["cuda"] @property - def skip_models_for_freezing(self): - return self._skip["freezing"] + def skip_models_for_freezing_cuda(self): + return self._skip["freezing"]["cuda"] + + @property + def skip_models_for_freezing_cpu(self): + return self._skip["freezing"]["cpu"] @property def slow_models(self): diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 84c7de12a36ee6..0b9a083515a69a 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -191,6 +191,7 @@ skip: - hf_Whisper - stable_diffusion_text_encoder - llava + - moco cuda: [] @@ -232,10 +233,13 @@ skip: # for these models, conv-batchnorm fusing causes big numerical churn. # Skip them + # mnasnet1_0 and shufflenet_v2_x1_0 can pass on cpu, moco cuda only. freezing: - - mnasnet1_0 - - moco - - shufflenet_v2_x1_0 + cuda: + - mnasnet1_0 + - moco + - shufflenet_v2_x1_0 + cpu: [] diff --git a/benchmarks/gpt_fast/benchmark.py b/benchmarks/gpt_fast/benchmark.py index 95e46cb87ca9bb..6270f9744f6a73 100644 --- a/benchmarks/gpt_fast/benchmark.py +++ b/benchmarks/gpt_fast/benchmark.py @@ -3,7 +3,12 @@ import dataclasses import os -from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8 +from generate import ( + get_arch_name, + run_llama2_7b_bf16, + run_llama2_7b_int8, + run_mixtral_8x7b_int8, +) import torch import torch.nn as nn @@ -24,6 +29,7 @@ class Experiment: actual: float dtype: str device: str + arch: str # GPU name for CUDA or CPU arch for CPU is_model: bool = False @@ -71,7 +77,12 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): for _ in range(WARMUP_ITER): compiled_mod(x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000 flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS flops_utilization = flops_utilization / len(input_shapes) @@ -84,6 +95,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"): f"{flops_utilization:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -108,7 +120,12 @@ def run_layer_norm(device: str = "cuda"): for _ in range(WARMUP_ITER): compiled_mod(x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -121,6 +138,7 @@ def run_layer_norm(device: str = "cuda"): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -151,9 +169,12 @@ def gather_gemv(W, score_idxs, x): for _ in range(WARMUP_ITER): compiled_fn(W, score_idxs, x) - us_per_iter = ( - benchmarker.benchmark_gpu(lambda: compiled_fn(W, score_idxs, x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu ) + us_per_iter = benchmark_fn(lambda: compiled_fn(W, score_idxs, x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -166,6 +187,7 @@ def gather_gemv(W, score_idxs, x): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -186,15 +208,20 @@ def run_gemv(device: str = "cuda"): def gemv(W, x): return W.to(x.dtype) @ x - W = torch.randn(D, D, device="cuda").to(dtype=dtype) - x = torch.randn(D, device="cuda", dtype=torch.bfloat16) + W = torch.randn(D, D, device=device).to(dtype=dtype) + x = torch.randn(D, device=device, dtype=torch.bfloat16) compiled_fn = torch.compile(gemv, dynamic=False) for _ in range(WARMUP_ITER): compiled_fn(W, x) - us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_fn(W, x)) * 1000 + benchmark_fn = ( + benchmarker.benchmark_gpu + if device == "cuda" + else benchmarker.benchmark_cpu + ) + us_per_iter = benchmark_fn(lambda: compiled_fn(W, x)) * 1000 memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9 memory_bandwidth = memory_bandwidth / len(input_shapes) @@ -207,6 +234,7 @@ def gemv(W, x): f"{memory_bandwidth:.02f}", dtype_str, device, + get_arch_name(), ) ) return results @@ -252,7 +280,13 @@ def main(output_file=DEFAULT_OUTPUT_FILE): results = [] for func in all_experiments: - lst = func() + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + except AssertionError: + # This happens when torch is compiled with CUDA turning off completely + device = "cpu" + + lst = func(device) for x in lst: results.append(dataclasses.astuple(x)) diff --git a/benchmarks/gpt_fast/generate.py b/benchmarks/gpt_fast/generate.py index c2aa6e6d17fec4..56e6cff1cf8ced 100644 --- a/benchmarks/gpt_fast/generate.py +++ b/benchmarks/gpt_fast/generate.py @@ -1,5 +1,6 @@ import dataclasses import itertools +import platform import time from typing import Optional, Tuple @@ -41,6 +42,14 @@ def device_sync(device): print(f"device={device} is not yet suppported") +def get_arch_name() -> str: + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + else: + # This returns x86_64 or arm64 (for aarch64) + return platform.machine() + + def multinomial_sample_one_no_sync( probs_sort, ): # Does multinomial sampling without a cuda synchronization @@ -198,7 +207,7 @@ def run_experiment( ) -> None: print(f"Loading model {x.name}") t0 = time.time() - model = _load_model(x) + model = _load_model(x, device=device) device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") @@ -255,9 +264,11 @@ def run_llama2_7b_bf16(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 94, 1253, - 162, + 133, + ) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -266,6 +277,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -275,6 +287,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -284,6 +297,7 @@ def run_llama2_7b_bf16(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] @@ -300,9 +314,11 @@ def run_llama2_7b_int8(device: str = "cuda"): LLaMAWeightOnlyInt8QuantHandler, 144, 957, - 172, + 136, + ) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -311,6 +327,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -320,6 +337,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -329,6 +347,7 @@ def run_llama2_7b_int8(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] @@ -346,9 +365,11 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): MixtralMoEWeightOnlyInt8QuantHandler, 175, 1130, - 162, + 133, + ) + token_per_sec, memory_bandwidth, compilation_time = run_experiment( + model, device=device ) - token_per_sec, memory_bandwidth, compilation_time = run_experiment(model) return [ Experiment( model.name, @@ -357,6 +378,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{token_per_sec:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -366,6 +388,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{memory_bandwidth:.02f}", model.mode, device, + get_arch_name(), True, ), Experiment( @@ -375,6 +398,7 @@ def run_mixtral_8x7b_int8(device: str = "cuda"): f"{compilation_time:.02f}", model.mode, device, + get_arch_name(), True, ), ] diff --git a/benchmarks/sparse/triton_ops.py b/benchmarks/sparse/triton_ops.py index fad89c280f8975..2493e1e0f74019 100644 --- a/benchmarks/sparse/triton_ops.py +++ b/benchmarks/sparse/triton_ops.py @@ -28,9 +28,7 @@ def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): def _test_worker(test_func): - ms, ms_min, ms_max = benchmarker.benchmark_gpu( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms, ms_min, ms_max = benchmarker.benchmark_gpu(test_func, warmup=500, rep=100) tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3) return ms, tflops diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt index 273353128baafa..3405b1defb5c55 100644 --- a/binaries/CMakeLists.txt +++ b/binaries/CMakeLists.txt @@ -35,12 +35,6 @@ if(USE_ROCM) endif() -if(USE_MPI) - caffe2_binary_target("run_plan_mpi.cc") - target_link_libraries(run_plan_mpi ${MPI_CXX_LIBRARIES}) -endif() - - caffe2_binary_target("dump_operator_names.cc") caffe2_binary_target("optimize_for_mobile.cc") diff --git a/build.bzl b/build.bzl index 9ada64dc302ad1..dbb1866ac54823 100644 --- a/build.bzl +++ b/build.bzl @@ -76,7 +76,7 @@ def define_targets(rules): ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else [])) gen_aten_outs_cuda = ( - GENERATED_H_CUDA + GENERATED_CPP_CUDA + + GENERATED_H_CUDA + GENERATED_CPP_CUDA + GENERATED_AOTI_CUDA_CPP + aten_ufunc_generated_cuda_sources() ) @@ -320,5 +320,8 @@ GENERATED_AUTOGRAD_CPP = [ GENERATED_AOTI_CPP = [ "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp", +] + +GENERATED_AOTI_CUDA_CPP = [ "torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.cpp", ] diff --git a/build_variables.bzl b/build_variables.bzl index 98b721617b609c..d11bba1ae1f37a 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -466,13 +466,16 @@ lazy_tensor_core_python_sources = [ ] inductor_core_resources = [ + "torch/csrc/inductor/aoti_package/model_package_loader.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp", "torch/csrc/inductor/aoti_torch/shim_common.cpp", + "torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp", "torch/csrc/inductor/aoti_torch/tensor_converter.cpp", "torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp", "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp", "torch/csrc/inductor/inductor_ops.cpp", + "torch/csrc/jit/serialization/pickle.cpp", ] libtorch_core_sources = sorted( @@ -689,7 +692,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/Utils.cu", + "torch/csrc/distributed/c10d/NanCheck.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] @@ -837,11 +840,13 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/extra_state.cpp", "torch/csrc/dynamo/framelocals_mapping.cpp", "torch/csrc/dynamo/guards.cpp", + "torch/csrc/dynamo/utils.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", "torch/csrc/fx/node.cpp", "torch/csrc/mps/Module.cpp", "torch/csrc/mtia/Module.cpp", + "torch/csrc/inductor/aoti_package/pybind.cpp", "torch/csrc/inductor/aoti_runner/pybind.cpp", "torch/csrc/inductor/aoti_eager/kernel_holder.cpp", "torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp", diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 80e172497d5e66..34577caef2ec6a 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -127,6 +127,7 @@ if(NOT BUILD_LIBTORCHLESS) if(LINUX) target_link_libraries(c10 PRIVATE Threads::Threads) + target_link_libraries(c10 PRIVATE dl) endif() if(ANDROID) diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h new file mode 100644 index 00000000000000..8724ecf88ae0f9 --- /dev/null +++ b/c10/core/CachingDeviceAllocator.h @@ -0,0 +1,131 @@ +#pragma once + +#include +#include + +#include + +namespace c10::CachingDeviceAllocator { + +struct Stat { + void increase(size_t amount) { + current += static_cast(amount); + peak = std::max(current, peak); + allocated += static_cast(amount); + } + + void decrease(size_t amount) { + current -= static_cast(amount); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + current >= 0, + "Negative tracked stat in device allocator (likely logic error)."); + freed += static_cast(amount); + } + + void reset_accumulated() { + allocated = 0; + freed = 0; + } + + void reset_peak() { + peak = current; + } + + int64_t current = 0; + int64_t peak = 0; + int64_t allocated = 0; + int64_t freed = 0; +}; + +enum struct StatType : uint64_t { + AGGREGATE = 0, + SMALL_POOL = 1, + LARGE_POOL = 2, + NUM_TYPES = 3 // remember to update this whenever a new stat type is added +}; + +using StatArray = std::array(StatType::NUM_TYPES)>; +using StatTypes = std::array(StatType::NUM_TYPES)>; + +template +void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { + for (const auto stat_type : c10::irange(stat_types.size())) { + if (stat_types[stat_type]) { + f(stat_type); + } + } +} + +// Struct containing memory allocator summary statistics for a device. +struct DeviceStats { + // COUNT: allocations requested by client code + StatArray allocation; + // COUNT: number of allocated segments from device memory allocation. + StatArray segment; + // COUNT: number of active memory blocks (allocated or used by stream) + StatArray active; + // COUNT: number of inactive, split memory blocks (unallocated but can't be + // released via device memory deallocation) + StatArray inactive_split; + + // SUM: bytes allocated by this memory alocator + StatArray allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + StatArray reserved_bytes; + // SUM: bytes within active memory blocks + StatArray active_bytes; + // SUM: bytes within inactive, split memory blocks + StatArray inactive_split_bytes; + // SUM: bytes requested by client code + StatArray requested_bytes; + + // COUNT: total number of failed calls to device malloc necessitating cache + // flushes. + int64_t num_alloc_retries = 0; + + // COUNT: total number of OOMs (i.e. failed calls to device memory allocation + // after cache flush) + int64_t num_ooms = 0; + + // COUNT: total number of oversize blocks allocated from pool + Stat oversize_allocations; + + // COUNT: total number of oversize blocks requiring malloc + Stat oversize_segments; + + // COUNT: total number of synchronize_and_free_events() calls + int64_t num_sync_all_streams = 0; + + // COUNT: total number of device memory allocation calls. This includes both + // mapped and malloced memory. + int64_t num_device_alloc = 0; + + // COUNT: total number of device memory deallocation calls. This includes both + // un-mapped and free memory. + int64_t num_device_free = 0; + + // SIZE: maximum block size that is allowed to be split. + int64_t max_split_size = 0; +}; + +// Size pretty-printer +inline std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (static_cast(size) / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << static_cast(size) / 1048576.0; + os << " MiB"; + } else { + os << static_cast(size) / 1073741824.0; + os << " GiB"; + } + return os.str(); +} + +} // namespace c10::CachingDeviceAllocator diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 1543db0a39bcf9..70d3d7489e548d 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -145,6 +146,13 @@ void register_privateuse1_backend(const std::string& backend_name) { "torch.register_privateuse1_backend() has already been set! Current backend: ", privateuse1_backend_name); + static const std::array types = { + "cpu", "cuda", "hip", "mps", "xpu", "mtia"}; + TORCH_CHECK( + std::find(types.begin(), types.end(), backend_name) == types.end(), + "Cannot register privateuse1 backend with in-tree device name: ", + backend_name); + privateuse1_backend_name = backend_name; // Invariant: once this flag is set, privateuse1_backend_name is NEVER written // to. diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 0388234efd5b34..526e7f079ee5ad 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -149,6 +149,8 @@ const char* toString(DispatchKey t) { return "AutocastXLA"; case DispatchKey::AutocastPrivateUse1: return "AutocastPrivateUse1"; + case DispatchKey::AutocastMPS: + return "AutocastMPS"; case DispatchKey::FuncTorchBatched: return "FuncTorchBatched"; @@ -297,6 +299,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"AutocastCUDA", c10::DispatchKey::AutocastCUDA}, {"AutocastXLA", c10::DispatchKey::AutocastXLA}, {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1}, + {"AutocastMPS", c10::DispatchKey::AutocastMPS}, {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched}, {"BatchedNestedTensor", c10::DispatchKey::BatchedNestedTensor}, {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode}, diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 5e417ae4a3ddb1..fc5bdabd18fdd4 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -359,6 +359,7 @@ enum class DispatchKey : uint16_t { AutocastXLA, // AutocastXLA is only being used for TPUs. XLA GPUs continue to use // AutocastCUDA. + AutocastMPS, AutocastCUDA, AutocastPrivateUse1, diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index ef020071fbc2c9..ca54e1966c5e65 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -655,6 +655,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -671,6 +672,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ DispatchKey::AutocastCPU, + DispatchKey::AutocastMPS, DispatchKey::AutocastCUDA, DispatchKey::AutocastXPU, DispatchKey::AutocastIPU, @@ -863,6 +865,7 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA); constexpr auto autocast_privateuse1_ks = DispatchKeySet(DispatchKey::AutocastPrivateUse1); + constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS); switch (t) { case BackendComponent::CPUBit: return autocast_cpu_ks; @@ -878,6 +881,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { return autocast_xla_ks; case BackendComponent::PrivateUse1Bit: return autocast_privateuse1_ks; + case BackendComponent::MPSBit: + return autocast_mps_ks; default: return DispatchKeySet(); } diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 0847af2fce5340..a7ab26a24804fb 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -239,5 +239,4 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { }; } // namespace c10 - -C10_CLANG_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 130292aaa70d6a..40bf133d2587ee 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -509,7 +509,9 @@ c10::intrusive_ptr TensorImpl::shallow_copy_and_detach_core( r = (pyobj_slot_.load_pyobj_interpreter())->detach(this); } if (r) { - r->set_version_counter(std::forward(version_counter)); + if (!r->is_inference()) { + r->set_version_counter(std::forward(version_counter)); + } r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); return r; } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index f1abeb0c33eae4..a8d05dddcfa26c 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2034,7 +2034,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { BackendComponent::MPSBit, BackendComponent::HIPBit, BackendComponent::XPUBit, - BackendComponent::HPUBit}); + BackendComponent::HPUBit, + BackendComponent::MTIABit}); constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense); return ts.has_any(dense_k) && ts.has_any(dense_backends); }; diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 19aedb2cbb02f7..7c1c1e02644b25 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -12,11 +12,13 @@ constexpr size_t kRoundUpPowerOfTwoIntervals = 16; CUDAAllocatorConfig::CUDAAllocatorConfig() : m_max_split_size(std::numeric_limits::max()), + m_max_non_split_rounding_size(kLargeBuffer), m_garbage_collection_threshold(0), m_pinned_num_register_threads(1), m_expandable_segments(false), m_release_lock_on_cudamalloc(false), m_pinned_use_cuda_host_register(false), + m_pinned_use_background_threads(false), m_last_allocator_settings("") { m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); } @@ -94,6 +96,27 @@ size_t CUDAAllocatorConfig::parseMaxSplitSize( return i; } +size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_non_split_rounding_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_non_split_rounding_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); + } + return i; +} + size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( const std::vector& config, size_t i) { @@ -258,6 +281,9 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { if (config_item_view == "max_split_size_mb") { i = parseMaxSplitSize(config, i); used_native_specific_option = true; + } else if (config_item_view == "max_non_split_rounding_mb") { + i = parseMaxNonSplitRoundingSize(config, i); + used_native_specific_option = true; } else if (config_item_view == "garbage_collection_threshold") { i = parseGarbageCollectionThreshold(config, i); used_native_specific_option = true; @@ -306,6 +332,9 @@ void CUDAAllocatorConfig::parseArgs(const char* env) { } else if (config_item_view == "pinned_num_register_threads") { i = parsePinnedNumRegisterThreads(config, i); used_native_specific_option = true; + } else if (config_item_view == "pinned_use_background_threads") { + i = parsePinnedUseBackgroundThreads(config, i); + used_native_specific_option = true; } else { TORCH_CHECK( false, "Unrecognized CachingAllocator option: ", config_item_view); @@ -363,6 +392,22 @@ size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( return i; } +size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_background_threads"); + m_pinned_use_background_threads = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_background_threads value", ""); + } + return i; +} + // General caching allocator utilities void setAllocatorSettings(const std::string& env) { CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 3106fc1b46baee..876bffcf98527a 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -46,6 +46,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } + static bool pinned_use_background_threads() { + return instance().m_pinned_use_background_threads; + } + static size_t pinned_max_register_threads() { // Based on the benchmark results, we see better allocation performance // with 8 threads. However on future systems, we may need more threads @@ -63,6 +67,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_roundup_power2_divisions; } + static size_t max_non_split_rounding_size() { + return instance().m_max_non_split_rounding_size; + } + static std::string last_allocator_settings() { std::lock_guard lock( instance().m_last_allocator_settings_mutex); @@ -90,6 +98,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t i, const char c); size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i); size_t parseGarbageCollectionThreshold( const std::vector& config, size_t i); @@ -106,14 +117,19 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parsePinnedNumRegisterThreads( const std::vector& config, size_t i); + size_t parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i); std::atomic m_max_split_size; + std::atomic m_max_non_split_rounding_size; std::vector m_roundup_power2_divisions; std::atomic m_garbage_collection_threshold; std::atomic m_pinned_num_register_threads; std::atomic m_expandable_segments; std::atomic m_release_lock_on_cudamalloc; std::atomic m_pinned_use_cuda_host_register; + std::atomic m_pinned_use_background_threads; std::string m_last_allocator_settings; std::mutex m_last_allocator_settings_mutex; }; diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 4a3a2c7545c54c..a67a720717bb7f 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -6,11 +6,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include @@ -27,7 +27,6 @@ #include #include #include -#include #include #include #include @@ -44,6 +43,8 @@ C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace cuda::CUDACachingAllocator { +using namespace c10::CachingDeviceAllocator; + // Included here as this is externally used in CUDAAllocatorConfig const size_t kLargeBuffer = 20971520; // "large" allocations may be packed in 20 MiB blocks @@ -134,47 +135,13 @@ namespace { using stream_set = ska::flat_hash_set; -using StatTypes = std::array(StatType::NUM_TYPES)>; - -void increase_stat(Stat& stat, size_t amount) { - stat.current += static_cast(amount); - stat.peak = std::max(stat.current, stat.peak); - stat.allocated += static_cast(amount); -} - -void decrease_stat(Stat& stat, size_t amount) { - stat.current -= static_cast(amount); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - stat.current >= 0, - "Negative tracked stat in CUDA allocator (likely logic error)."); - stat.freed += static_cast(amount); -} - -void reset_accumulated_stat(Stat& stat) { - stat.allocated = 0; - stat.freed = 0; -} - -void reset_peak_stat(Stat& stat) { - stat.peak = stat.current; -} - -template -void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { - for (const auto stat_type : c10::irange(stat_types.size())) { - if (stat_types[stat_type]) { - f(stat_type); - } - } -} - void decrease_stat_array( StatArray& stat_array, size_t amount, const StatTypes& stat_types) { for_each_selected_stat_type( stat_types, [&stat_array, amount](size_t stat_type) { - decrease_stat(stat_array[stat_type], amount); + stat_array[stat_type].decrease(amount); }); } @@ -1424,16 +1391,16 @@ class DeviceCachingAllocator { // A new split inactive block is being created from a previously unsplit // block, size remaining->size bytes. for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.inactive_split_bytes[stat_type], remaining->size); - increase_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].increase(remaining->size); + stats.inactive_split[stat_type].increase(1); }); } } else if (already_split && !block->expandable_segment_) { // An already-split block is becoming active for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - decrease_stat(stats.inactive_split_bytes[stat_type], block->size); - decrease_stat(stats.inactive_split[stat_type], 1); + stats.inactive_split_bytes[stat_type].decrease(block->size); + stats.inactive_split[stat_type].decrease(1); }); } @@ -1454,14 +1421,20 @@ class DeviceCachingAllocator { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { - increase_stat(stats.allocation[stat_type], 1); - increase_stat(stats.allocated_bytes[stat_type], block->size); - increase_stat(stats.active[stat_type], 1); - increase_stat(stats.active_bytes[stat_type], block->size); - increase_stat(stats.requested_bytes[stat_type], block->requested_size); + stats.allocation[stat_type].increase(1); + stats.allocated_bytes[stat_type].increase(block->size); + stats.active[stat_type].increase(1); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); }); if (block->size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.increase(1); + + auto allocated_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); + allocated_bytes_gauge.record( + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current); c10::reportMemoryUsageToProfiler( block->ptr, @@ -1487,9 +1460,14 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.allocation[stat_type], 1); - decrease_stat(stats.allocated_bytes[stat_type], block->size); + stats.allocation[stat_type].decrease(1); + stats.allocated_bytes[stat_type].decrease(block->size); }); + auto allocated_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes); + allocated_bytes_gauge.record( + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current); record_trace( TraceEntry::FREE_REQUESTED, @@ -1500,7 +1478,7 @@ class DeviceCachingAllocator { context ? context : block->context_when_allocated); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_allocations, 1); + stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { if (C10_UNLIKELY(!captures_underway.empty())) { @@ -1628,15 +1606,15 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_accumulated_stat(stats.allocation[statType]); - reset_accumulated_stat(stats.segment[statType]); - reset_accumulated_stat(stats.active[statType]); - reset_accumulated_stat(stats.inactive_split[statType]); - reset_accumulated_stat(stats.allocated_bytes[statType]); - reset_accumulated_stat(stats.reserved_bytes[statType]); - reset_accumulated_stat(stats.active_bytes[statType]); - reset_accumulated_stat(stats.inactive_split_bytes[statType]); - reset_accumulated_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_accumulated(); + stats.segment[statType].reset_accumulated(); + stats.active[statType].reset_accumulated(); + stats.inactive_split[statType].reset_accumulated(); + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.inactive_split_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); } stats.num_alloc_retries = 0; @@ -1644,8 +1622,8 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; - reset_accumulated_stat(stats.oversize_allocations); - reset_accumulated_stat(stats.oversize_segments); + stats.oversize_allocations.reset_accumulated(); + stats.oversize_segments.reset_accumulated(); } /** Resets the historical peak stats for the device **/ @@ -1654,18 +1632,18 @@ class DeviceCachingAllocator { for (const auto statType : c10::irange(static_cast(StatType::NUM_TYPES))) { - reset_peak_stat(stats.allocation[statType]); - reset_peak_stat(stats.segment[statType]); - reset_peak_stat(stats.active[statType]); - reset_peak_stat(stats.inactive_split[statType]); - reset_peak_stat(stats.allocated_bytes[statType]); - reset_peak_stat(stats.reserved_bytes[statType]); - reset_peak_stat(stats.active_bytes[statType]); - reset_peak_stat(stats.inactive_split_bytes[statType]); - reset_peak_stat(stats.requested_bytes[statType]); + stats.allocation[statType].reset_peak(); + stats.segment[statType].reset_peak(); + stats.active[statType].reset_peak(); + stats.inactive_split[statType].reset_peak(); + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.inactive_split_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); } - reset_peak_stat(stats.oversize_allocations); - reset_peak_stat(stats.oversize_segments); + stats.oversize_allocations.reset_peak(); + stats.oversize_segments.reset_peak(); } /* Checkpoint the state of a private pool necessary to return it to its @@ -2277,8 +2255,13 @@ class DeviceCachingAllocator { total_allocated_memory += mapped_range.size; StatTypes stat_types = get_stat_types_for_pool(*to_map->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - increase_stat(stats.reserved_bytes[stat_type], mapped_range.size); + stats.reserved_bytes[stat_type].increase(mapped_range.size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); stats.num_device_alloc++; record_trace( @@ -2384,27 +2367,23 @@ class DeviceCachingAllocator { // inactive_split if (!block->expandable_segment_) { if (net_change_inactive_split_blocks > 0) { - increase_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].increase( static_cast(net_change_inactive_split_blocks)); } else if (net_change_inactive_split_blocks < 0) { - decrease_stat( - stats.inactive_split[stat_type], + stats.inactive_split[stat_type].decrease( static_cast(-net_change_inactive_split_blocks)); } if (net_change_inactive_split_size > 0) { - increase_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].increase( static_cast(net_change_inactive_split_size)); } else if (net_change_inactive_split_size < 0) { - decrease_stat( - stats.inactive_split_bytes[stat_type], + stats.inactive_split_bytes[stat_type].decrease( static_cast(-net_change_inactive_split_size)); } } - decrease_stat(stats.active[stat_type], 1); - decrease_stat(stats.active_bytes[stat_type], original_block_size); - decrease_stat(stats.requested_bytes[stat_type], requested_size); + stats.active[stat_type].decrease(1); + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); }); } @@ -2548,7 +2527,8 @@ class DeviceCachingAllocator { return false; // Allow oversized block size to be rounded up but within a limit if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && - ((*it)->size >= p.size() + kLargeBuffer)) + ((*it)->size >= + p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); @@ -2611,7 +2591,7 @@ class DeviceCachingAllocator { while (it != large_blocks.blocks.end()) { Block* block = *it; ++it; - if (!block->is_split() && + if (!block->is_split() && !block->expandable_segment_ && static_cast(block->gc_count()) >= age_threshold) { block_freed = true; gc_reclaimed += block->size; @@ -2674,7 +2654,12 @@ class DeviceCachingAllocator { // any potential exceptions in the cudaMallocMaybeCapturing function. auto sg = c10::make_scope_exit([&]() { lock.lock(); }); lock.unlock(); - p.err = cudaMallocMaybeCapturing(&ptr, size); + } + auto active_pool = MemPoolContext::getActiveMemPool(); + if (active_pool && active_pool->allocator() && + p.pool->owner_PrivatePool) { + ptr = active_pool->allocator()->raw_alloc(size); + p.err = ptr ? cudaSuccess : cudaErrorMemoryAllocation; } else { p.err = cudaMallocMaybeCapturing(&ptr, size); } @@ -2711,11 +2696,16 @@ class DeviceCachingAllocator { total_allocated_memory += size; p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { - increase_stat(stats.segment[stat_type], 1); - increase_stat(stats.reserved_bytes[stat_type], size); + stats.segment[stat_type].increase(1); + stats.reserved_bytes[stat_type].increase(size); }); if (size >= CUDAAllocatorConfig::max_split_size()) - increase_stat(stats.oversize_segments, 1); + stats.oversize_segments.increase(1); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); // p.block came from new, not cudaMalloc. It should not be nullptr here. TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); @@ -2749,7 +2739,8 @@ class DeviceCachingAllocator { ? CUDAAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); - if (it == pool.blocks.end() || (*it)->stream != p.stream()) { + if (it == pool.blocks.end() || (*it)->stream != p.stream() || + (*it)->expandable_segment_) { // No single block is large enough; free multiple oversize blocks, // starting with the largest if (it == pool.blocks.begin()) @@ -2761,12 +2752,15 @@ class DeviceCachingAllocator { ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; - totalReleased += (*it)->size; - if (it != pool.blocks.begin()) { + bool is_first = cur == pool.blocks.begin(); + if (!is_first) { --it; + } + if (!(*cur)->expandable_segment_) { release_block(*cur, context); - } else { - release_block(*cur, context); + totalReleased += (*cur)->size; + } + if (is_first) { break; } } @@ -2846,12 +2840,17 @@ class DeviceCachingAllocator { StatTypes stat_types = get_stat_types_for_pool(*pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.segment[stat_type], 1); - decrease_stat(stats.reserved_bytes[stat_type], block->size); + stats.segment[stat_type].decrease(1); + stats.reserved_bytes[stat_type].decrease(block->size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); if (block->size >= CUDAAllocatorConfig::max_split_size()) - decrease_stat(stats.oversize_segments, 1); + stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; } @@ -2903,8 +2902,13 @@ class DeviceCachingAllocator { total_allocated_memory -= unmapped.size; StatTypes stat_types = get_stat_types_for_pool(*block->pool); for_each_selected_stat_type(stat_types, [&](size_t stat_type) { - decrease_stat(stats.reserved_bytes[stat_type], unmapped.size); + stats.reserved_bytes[stat_type].decrease(unmapped.size); }); + auto reserved_bytes_gauge = + STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); + reserved_bytes_gauge.record( + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current); if (block->pool->owner_PrivatePool) { // The cudaFreed block belonged to a CUDA graph's PrivatePool. @@ -3732,25 +3736,6 @@ void local_raw_delete(void* ptr) { } } // namespace Native -// Size pretty-printer -std::string format_size(uint64_t size) { - std::ostringstream os; - os.precision(2); - os << std::fixed; - if (size <= 1024) { - os << size << " bytes"; - } else if (size <= 1048576) { - os << (static_cast(size) / 1024.0); - os << " KiB"; - } else if (size <= 1073741824ULL) { - os << static_cast(size) / 1048576.0; - os << " MiB"; - } else { - os << static_cast(size) / 1073741824.0; - os << " GiB"; - } - return os.str(); -} namespace CudaMallocAsync { // If this is put in its own header file, it gets incorrectly renamed in HIPify. diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 72617bcaf3a944..70385654201f50 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -48,74 +48,11 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); namespace c10::cuda::CUDACachingAllocator { -extern const size_t kLargeBuffer; - -struct Stat { - int64_t current = 0; - int64_t peak = 0; - int64_t allocated = 0; - int64_t freed = 0; -}; - -enum struct StatType : uint64_t { - AGGREGATE = 0, - SMALL_POOL = 1, - LARGE_POOL = 2, - NUM_TYPES = 3 // remember to update this whenever a new stat type is added -}; +// Preserved only for BC reasons +// NOLINTNEXTLINE(misc-unused-using-decls) +using c10::CachingDeviceAllocator::DeviceStats; -typedef std::array(StatType::NUM_TYPES)> StatArray; - -// Struct containing memory allocator summary statistics for a device. -struct DeviceStats { - // COUNT: allocations requested by client code - StatArray allocation; - // COUNT: number of allocated segments from cudaMalloc(). - StatArray segment; - // COUNT: number of active memory blocks (allocated or used by stream) - StatArray active; - // COUNT: number of inactive, split memory blocks (unallocated but can't be - // released via cudaFree) - StatArray inactive_split; - - // SUM: bytes allocated by this memory alocator - StatArray allocated_bytes; - // SUM: bytes reserved by this memory allocator (both free and used) - StatArray reserved_bytes; - // SUM: bytes within active memory blocks - StatArray active_bytes; - // SUM: bytes within inactive, split memory blocks - StatArray inactive_split_bytes; - // SUM: bytes requested by client code - StatArray requested_bytes; - - // COUNT: total number of failed calls to CUDA malloc necessitating cache - // flushes. - int64_t num_alloc_retries = 0; - - // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush) - int64_t num_ooms = 0; - - // COUNT: total number of oversize blocks allocated from pool - Stat oversize_allocations; - - // COUNT: total number of oversize blocks requiring malloc - Stat oversize_segments; - - // COUNT: total number of synchronize_and_free_events() calls - int64_t num_sync_all_streams = 0; - - // COUNT: total number of CUDA allocation calls. This includes both cuMemMap - // and cudaMalloc. - int64_t num_device_alloc = 0; - - // COUNT: total number of CUDA free calls. This includes both cuMemUnmap - // and cudaFree. - int64_t num_device_free = 0; - - // SIZE: maximum block size that is allowed to be split. - int64_t max_split_size = 0; -}; +extern const size_t kLargeBuffer; typedef std::shared_ptr (*CreateContextFn)(); @@ -247,9 +184,6 @@ enum struct RecordContext { ALL = 3, // additionally record stacks for when something is freed }; -// Size pretty-printer -std::string format_size(uint64_t size); - using OutOfMemoryObserver = std::functionrecordStream(dataPtr, stream); } -inline DeviceStats getDeviceStats(c10::DeviceIndex device) { +inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) { return get()->getDeviceStats(device); } diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 3fe414b55e30a2..3a7b485ebce223 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -11,6 +11,8 @@ namespace c10::cuda::CUDACachingAllocator::CudaMallocAsync { +using namespace c10::CachingDeviceAllocator; + #if CUDA_VERSION >= 11040 // CUDA device allocator that uses cudaMallocAsync to implement // the same interface as CUDACachingAllocator.cpp. diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 51b54105297082..ab6f2b38cf6be7 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -345,6 +345,7 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) // Those platforms do not support assert() #define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) #define SYCL_KERNEL_ASSERT(cond) #elif defined(_MSC_VER) #if defined(NDEBUG) @@ -372,6 +373,16 @@ __host__ __device__ static_cast(__LINE__)), \ 0); \ } +// TODO: This doesn't assert the message because I (chilli) couldn't figure out +// a nice way to convert a char* to a wchar_t* +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } #define SYCL_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ (void)(_wassert( \ @@ -413,6 +424,7 @@ __host__ __device__ // ROCm disable kernel assert by default #if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) #define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) #define SYCL_KERNEL_ASSERT(cond) #else #define CUDA_KERNEL_ASSERT(cond) \ @@ -420,6 +432,11 @@ __host__ __device__ __assert_fail( \ #cond, __FILE__, static_cast(__LINE__), __func__); \ } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + msg, __FILE__, static_cast(__LINE__), __func__); \ + } #define SYCL_KERNEL_ASSERT(cond) \ if (C10_UNLIKELY(!(cond))) { \ __assert_fail( \ diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index c655d7953a0b4f..2fc569135e267c 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include diff --git a/c10/test/core/Device_test.cpp b/c10/test/core/Device_test.cpp index 0524ad08e8ec37..3874ea98f2d379 100644 --- a/c10/test/core/Device_test.cpp +++ b/c10/test/core/Device_test.cpp @@ -54,3 +54,12 @@ TEST(DeviceTypeTest, PrivateUseOneDeviceType) { ASSERT_EQ(c10::get_privateuse1_backend(true), "my_privateuse1_backend"); ASSERT_EQ(c10::get_privateuse1_backend(false), "MY_PRIVATEUSE1_BACKEND"); } + +TEST(DeviceTypeTest, PrivateUseOneRegister) { + ASSERT_THROW(c10::register_privateuse1_backend("cpu"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("cuda"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("hip"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("mps"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("xpu"), c10::Error); + ASSERT_THROW(c10::register_privateuse1_backend("mtia"), c10::Error); +} diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h index 88a6b849d37bf9..bad374cbd4353e 100644 --- a/c10/util/BFloat16-math.h +++ b/c10/util/BFloat16-math.h @@ -68,6 +68,12 @@ template < inline T expm1(T a) { return std::expm1(float(a)); } +template < + typename T, + typename std::enable_if_t, int> = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} template < typename T, typename std::enable_if_t, int> = 0> @@ -237,10 +243,9 @@ C10_HOST_DEVICE inline T nextafter(T from, T to) { // Reference: // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c using int_repr_t = uint16_t; - using float_t = T; constexpr uint8_t bits = 16; union { - float_t f; + T f; int_repr_t i; } ufrom = {from}, uto = {to}; diff --git a/c10/util/Gauge.cpp b/c10/util/Gauge.cpp new file mode 100644 index 00000000000000..4f3a2214e725be --- /dev/null +++ b/c10/util/Gauge.cpp @@ -0,0 +1,79 @@ +#include + +#include + +#include +#include +#include +#include +#include + +namespace c10::monitor { + +namespace detail { +namespace { +using GaugeBackendFactories = + std::vector>; + +Synchronized& gaugeBackendFactories() { + static auto instance = new Synchronized(); + return *instance; +} +} // namespace + +class GaugeImpl { + public: + static GaugeImpl& getInstance(std::string_view key) { + static auto& implMapSynchronized = *new Synchronized< + std::unordered_map>>(); + + return *implMapSynchronized.withLock([&](auto& implMap) { + if (auto implIt = implMap.find(std::string(key)); + implIt != implMap.end()) { + return implIt->second.get(); + } + + auto [implIt, emplaceSuccess] = implMap.emplace( + std::string{key}, std::unique_ptr(new GaugeImpl(key))); + + assert(emplaceSuccess); + + return implIt->second.get(); + }); + } + + void record(int64_t value) { + for (auto& backend : backends_) { + backend->record(value); + } + } + + private: + explicit GaugeImpl(std::string_view key) { + auto factoriesCopy = gaugeBackendFactories().withLock( + [](auto& factories) { return factories; }); + for (const auto& factory : factoriesCopy) { + if (auto backend = factory->create(key)) { + backends_.push_back(std::move(backend)); + } + } + } + + SmallVector> backends_; +}; + +void registerGaugeBackend(std::unique_ptr backend) { + gaugeBackendFactories().withLock( + [&](auto& backends) { backends.push_back(std::move(backend)); }); +} + +} // namespace detail + +GaugeHandle::GaugeHandle(std::string_view key) + : impl_(detail::GaugeImpl::getInstance(key)) {} + +void GaugeHandle::record(int64_t value) { + impl_.record(value); +} + +} // namespace c10::monitor diff --git a/c10/util/Gauge.h b/c10/util/Gauge.h new file mode 100644 index 00000000000000..f92ecd986bee1b --- /dev/null +++ b/c10/util/Gauge.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10::monitor { +namespace detail { + +class GaugeImpl; + +class GaugeBackendIf { + public: + virtual ~GaugeBackendIf() = default; + virtual void record(int64_t value) noexcept = 0; +}; + +class GaugeBackendFactoryIf { + public: + virtual ~GaugeBackendFactoryIf() = default; + + // May return nullptr if the gauge will be ignored by the given backend. + virtual std::unique_ptr create( + std::string_view key) noexcept = 0; +}; + +void C10_API registerGaugeBackend(std::unique_ptr); +} // namespace detail + +// A handle to a Gauge. +class C10_API GaugeHandle { + public: + explicit GaugeHandle(std::string_view key); + void record(int64_t value); + + private: + detail::GaugeImpl& impl_; +}; + +} // namespace c10::monitor + +#define STATIC_GAUGE(_key) \ + []() -> ::c10::monitor::GaugeHandle& { \ + static ::c10::monitor::GaugeHandle handle(#_key); \ + return handle; \ + }() diff --git a/c10/util/StringUtil.cpp b/c10/util/StringUtil.cpp index 1f5254a3dedaa2..b92802d956c806 100644 --- a/c10/util/StringUtil.cpp +++ b/c10/util/StringUtil.cpp @@ -41,15 +41,14 @@ std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString); #ifndef _WIN32 -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") // TODO (huydhn) https://en.cppreference.com/w/cpp/header/codecvt has been // deprecated in C++17 but there is no alternative yet, so I just ack it std::ostream& _strFromWide(std::ostream& ss, const std::wstring& wString) { std::wstring_convert> converter; return _str(ss, converter.to_bytes(wString)); } -#pragma GCC diagnostic pop +C10_DIAGNOSTIC_POP() #else // #ifndef _WIN32 // The WIN32 implementation of wstring_convert leaks memory; see diff --git a/c10/util/WaitCounter.cpp b/c10/util/WaitCounter.cpp index 03c2ac72939dc8..3941942dfb3500 100644 --- a/c10/util/WaitCounter.cpp +++ b/c10/util/WaitCounter.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -8,6 +9,10 @@ #include #include +#ifndef _WIN32 +#include +#endif + namespace c10::monitor { namespace detail { @@ -19,6 +24,58 @@ Synchronized& waitCounterBackendFactories() { static auto instance = new Synchronized(); return *instance; } + +class DynamicBackendWrapper : public WaitCounterBackendIf { + public: + explicit DynamicBackendWrapper(WaitCounterDynamicBackend impl) + : impl_{impl} {} + ~DynamicBackendWrapper() override { + impl_.destroy(impl_.self); + } + + intptr_t start(std::chrono::steady_clock::time_point now) noexcept override { + return impl_.start( + impl_.self, + std::chrono::duration_cast( + now.time_since_epoch()) + .count()); + } + + void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept + override { + return impl_.stop( + impl_.self, + std::chrono::duration_cast( + now.time_since_epoch()) + .count(), + ctx); + } + + private: + WaitCounterDynamicBackend impl_; +}; + +std::unique_ptr getDynamicBackend(std::string_view key) { + static auto dynamicBackendInit = + reinterpret_cast([]() -> void* { +#ifndef _WIN32 + return dlsym( + RTLD_DEFAULT, + std::string(kWaitCounterDynamicBackendInitFn).c_str()); +#else + return nullptr; +#endif + }()); + if (!dynamicBackendInit) { + return nullptr; + } + WaitCounterDynamicBackend backend; + dynamicBackendInit(&backend, &key[0], key.size()); + if (!backend.self) { + return nullptr; + } + return std::make_unique(backend); +} } // namespace class WaitCounterImpl { @@ -70,6 +127,9 @@ class WaitCounterImpl { backends_.push_back(std::move(backend)); } } + if (auto backend = getDynamicBackend(key)) { + backends_.push_back(std::move(backend)); + } } SmallVector> backends_; @@ -80,6 +140,12 @@ void registerWaitCounterBackend( waitCounterBackendFactories().withLock( [&](auto& factories) { factories.push_back(std::move(factory)); }); } + +std::vector> +getRegisteredWaitCounterBackends() { + return waitCounterBackendFactories().withLock( + [](auto& factories) { return factories; }); +} } // namespace detail WaitCounterHandle::WaitCounterHandle(std::string_view key) diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index dba8f82f3ca365..504e88720a9c12 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -36,6 +36,9 @@ class WaitCounterBackendFactoryIf { C10_API void registerWaitCounterBackend( std::unique_ptr); + +C10_API std::vector> +getRegisteredWaitCounterBackends(); } // namespace detail // A handle to a wait counter. diff --git a/c10/util/WaitCounterDynamicBackend.h b/c10/util/WaitCounterDynamicBackend.h new file mode 100644 index 00000000000000..ecbdc7f09f0341 --- /dev/null +++ b/c10/util/WaitCounterDynamicBackend.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace c10::monitor::detail { + +struct WaitCounterDynamicBackend { + void* self{nullptr}; + intptr_t (*start)(void* self, int64_t nowUs){nullptr}; + void (*stop)(void* self, int64_t nowUs, intptr_t ctx){nullptr}; + void (*destroy)(void* self){nullptr}; +}; + +using WaitCounterDynamicBackendInit = + void (*)(WaitCounterDynamicBackend*, const char* key, std::size_t keyLen); + +// This name needs to be updated if anything in the API above is changed. +constexpr std::string_view kWaitCounterDynamicBackendInitFn = + "c10_monitor_wait_counter_dynamic_backend_init_v1"; +} // namespace c10::monitor::detail diff --git a/c10/util/build.bzl b/c10/util/build.bzl index fb06a1517834d1..a6f95ae7516d36 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -43,6 +43,10 @@ def define_targets(rules): "//c10:using_glog": ["@com_github_glog//:glog"], "//conditions:default": [], }), + linkopts = rules.select({ + "@bazel_tools//src/conditions:windows": [], + "//conditions:default": ["-ldl"], + }), # This library uses flags and registration. Do not let the # linker remove them. alwayslink = True, diff --git a/c10/util/complex.h b/c10/util/complex.h index af810a780dd547..c08e10aa0f2933 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -261,19 +261,19 @@ struct alignas(sizeof(T) * 2) complex { #endif if (abs_c >= abs_d) { - if (abs_c == 0 && abs_d == 0) { + if (abs_c == U(0) && abs_d == U(0)) { /* divide by zeros should yield a complex inf or nan */ real_ = a / abs_c; imag_ = b / abs_d; } else { auto rat = d / c; - auto scl = 1.0 / (c + d * rat); + auto scl = U(1.0) / (c + d * rat); real_ = (a + b * rat) * scl; imag_ = (b - a * rat) * scl; } } else { auto rat = c / d; - auto scl = 1.0 / (d + c * rat); + auto scl = U(1.0) / (d + c * rat); real_ = (a * rat + b) * scl; imag_ = (b * rat - a) * scl; } diff --git a/c10/util/irange.h b/c10/util/irange.h index 3d7c607a1e9033..2719a82075cc96 100644 --- a/c10/util/irange.h +++ b/c10/util/irange.h @@ -2,7 +2,6 @@ #pragma once -#include #include #include diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index da57191fa1601c..067ddf5f82a4d7 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -9,6 +9,8 @@ namespace c10::xpu::XPUCachingAllocator { +using namespace c10::CachingDeviceAllocator; + // newly allocated memory with 512-byte alignment. constexpr size_t kDeviceAlignment = 512; // all sizes are rounded to at least 512 bytes @@ -117,6 +119,7 @@ struct AllocParams { BlockPool* pool; size_t alloc_size; Block* block; + StatTypes stat_types = {}; }; } // anonymous namespace @@ -124,6 +127,7 @@ struct AllocParams { class DeviceCachingAllocator { private: mutable std::recursive_mutex mutex; + DeviceStats stats; BlockPool large_blocks; // unallocated cached blocks larger than 1 MB BlockPool small_blocks; // unallocated cached blocks 1 MB or smaller ska::flat_hash_set active_blocks; // allocated or in use by a stream @@ -164,6 +168,8 @@ class DeviceCachingAllocator { !block->allocated && block->event_count == 0 && block->stream_uses.empty()); + size_t original_block_size = block->size; + size_t requested_size = block->requested_size; auto& pool = *block->pool; const std::array merge_candidates = {block->prev, block->next}; for (Block* merge_candidate : merge_candidates) { @@ -173,6 +179,12 @@ class DeviceCachingAllocator { active_blocks.erase(block); bool inserted = pool.blocks.insert(block).second; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); + + StatTypes stat_types = get_stat_types_for_pool(pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.active_bytes[stat_type].decrease(original_block_size); + stats.requested_bytes[stat_type].decrease(requested_size); + }); } void process_events() { @@ -250,6 +262,9 @@ class DeviceCachingAllocator { return false; } p.block = new Block(device, p.queue(), size, p.pool, ptr); + for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { + stats.reserved_bytes[stat_type].increase(size); + }); return true; } @@ -281,6 +296,12 @@ class DeviceCachingAllocator { sycl::free(block->ptr, xpu::get_device_context()); auto* pool = block->pool; pool->blocks.erase(block); + + StatTypes stat_types = get_stat_types_for_pool(*pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.reserved_bytes[stat_type].decrease(block->size); + }); + delete block; } @@ -314,6 +335,14 @@ class DeviceCachingAllocator { } } + StatTypes get_stat_types_for_pool(const BlockPool& pool) { + StatTypes stat_types = {}; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast( + pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true; + return stat_types; + } + Block* alloc_found_block( AllocParams params, size_t orig_size, @@ -350,6 +379,12 @@ class DeviceCachingAllocator { bool inserted = active_blocks.insert(block).second; TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted) + for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { + stats.allocated_bytes[stat_type].increase(block->size); + stats.active_bytes[stat_type].increase(block->size); + stats.requested_bytes[stat_type].increase(block->requested_size); + }); + return block; } @@ -376,6 +411,7 @@ class DeviceCachingAllocator { auto& pool = get_pool(size); const size_t alloc_size = get_allocation_size(size); AllocParams params(device, size, &queue, &pool, alloc_size); + params.stat_types = get_stat_types_for_pool(pool); // First, try to get a block from the existing pool. bool block_found = get_free_block(params); @@ -384,9 +420,32 @@ class DeviceCachingAllocator { block_found = alloc_block(params) || (release_cached_blocks() && alloc_block(params)); } - TORCH_CHECK( - block_found, - "XPU out of memory, please use `empty_cache` to release all unoccupied cached memory."); + if (!block_found) { + c10::xpu::DeviceProp device_prop; + c10::xpu::get_device_properties(&device_prop, device); + auto device_total = device_prop.global_mem_size; + auto allocated_bytes = + stats.allocated_bytes[static_cast(StatType::AGGREGATE)] + .current; + auto reserved_bytes = + stats.reserved_bytes[static_cast(StatType::AGGREGATE)] + .current; + TORCH_CHECK_WITH( + OutOfMemoryError, + false, + "XPU out of memory. Tried to allocate ", + format_size(alloc_size), + ". GPU ", + static_cast(device), + " has a total capacity of ", + format_size(device_total), + ". Of the allocated memory ", + format_size(allocated_bytes), + " is allocated by PyTorch, and ", + format_size(reserved_bytes - allocated_bytes), + " is reserved by PyTorch but unallocated.", + " Please use `empty_cache` to release all unoccupied cached memory."); + } bool split_remainder = should_split(params.block, params.size()); return alloc_found_block(std::move(params), orig_size, split_remainder); } @@ -395,6 +454,11 @@ class DeviceCachingAllocator { std::scoped_lock lock(mutex); block->allocated = false; + StatTypes stat_types = get_stat_types_for_pool(*block->pool); + for_each_selected_stat_type(stat_types, [&](size_t stat_type) { + stats.allocated_bytes[stat_type].decrease(block->size); + }); + if (!block->stream_uses.empty()) { insert_events(block); } else { @@ -414,6 +478,35 @@ class DeviceCachingAllocator { std::scoped_lock lock(mutex); release_cached_blocks(); } + + DeviceStats getStats() { + std::scoped_lock lock(mutex); + return stats; + } + + void resetAccumulatedStats() { + std::scoped_lock lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats.allocated_bytes[statType].reset_accumulated(); + stats.reserved_bytes[statType].reset_accumulated(); + stats.active_bytes[statType].reset_accumulated(); + stats.requested_bytes[statType].reset_accumulated(); + } + } + + void resetPeakStats() { + std::scoped_lock lock(mutex); + + for (const auto statType : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats.allocated_bytes[statType].reset_peak(); + stats.reserved_bytes[statType].reset_peak(); + stats.active_bytes[statType].reset_peak(); + stats.requested_bytes[statType].reset_peak(); + } + } }; void local_raw_delete(void* ptr); @@ -547,6 +640,30 @@ class XPUAllocator : public Allocator { void copy_data(void* dest, const void* src, std::size_t count) const final { xpu::getCurrentXPUStream().queue().memcpy(dest, src, count); } + + void assertValidDevice(DeviceIndex device) { + const auto device_num = device_allocators.size(); + TORCH_CHECK( + 0 <= device && device < static_cast(device_num), + "Invalid device argument ", + device, + ": did you call init?"); + } + + DeviceStats getDeviceStats(DeviceIndex device) { + assertValidDevice(device); + return device_allocators[device]->getStats(); + } + + void resetPeakStats(DeviceIndex device) { + assertValidDevice(device); + device_allocators[device]->resetPeakStats(); + } + + void resetAccumulatedStats(DeviceIndex device) { + assertValidDevice(device); + device_allocators[device]->resetAccumulatedStats(); + } }; static XPUAllocator allocator; @@ -567,6 +684,18 @@ void emptyCache() { return allocator.emptyCache(); } +void resetPeakStats(DeviceIndex device) { + return allocator.resetPeakStats(device); +} + +void resetAccumulatedStats(DeviceIndex device) { + return allocator.resetAccumulatedStats(device); +} + +DeviceStats getDeviceStats(DeviceIndex device) { + return allocator.getDeviceStats(device); +} + void* raw_alloc(size_t size) { return allocator.raw_alloc(size); } diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index 683654263a473d..6cdc8c8c71a6c9 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace c10::xpu::XPUCachingAllocator { @@ -11,6 +11,13 @@ C10_XPU_API void init(DeviceIndex device_count); C10_XPU_API void emptyCache(); +C10_XPU_API void resetPeakStats(DeviceIndex device); + +C10_XPU_API void resetAccumulatedStats(DeviceIndex device); + +C10_XPU_API c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + DeviceIndex device); + C10_XPU_API void* raw_alloc(size_t size); C10_XPU_API void raw_delete(void* ptr); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index d44a8da210462f..2160399a3ea296 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -767,7 +767,7 @@ if(NOT MSVC) set_source_files_properties(${PROJECT_SOURCE_DIR}/torch/csrc/distributed/c10d/socket.cpp PROPERTIES COMPILE_OPTIONS "-Wno-error=deprecated") endif() -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) +if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) @@ -781,7 +781,7 @@ if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_VULKAN AND NOT USE_IOS set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes") continue() endif() - string(FIND "${source_file}" "caffe2" res) + string(FIND "${source_file}" "embedding_lookup_idx_avx2.cc" res) if(res GREATER -1) set_source_files_properties(${source_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes") endif() @@ -927,7 +927,6 @@ elseif(USE_CUDA) set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_cuda) # see cmake/public/utils.cmake target_compile_definitions(torch_cuda PRIVATE USE_CUDA) - target_link_libraries(torch_cuda PRIVATE fmt::fmt-header-only) if(USE_CUFILE) target_link_libraries(torch_cuda PRIVATE torch::cufile) @@ -1334,7 +1333,6 @@ if(USE_ROCM) ${ROCM_SOURCE_DIR}/rocblas/include ${ROCM_SOURCE_DIR}/hipsparse/include ) - target_link_libraries(torch_hip PRIVATE fmt::fmt-header-only) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) endif() @@ -1563,20 +1561,24 @@ if(USE_XPU) target_link_libraries( torch_xpu PRIVATE ${Caffe2_XPU_DEPENDENCY_LIBS}) - include(CheckLinkerFlag) - - # Check whether the compiler supports '--no-as-needed' and '--as-needed' - check_linker_flag(CXX "-Wl,--no-as-needed" HAVE_NO_AS_NEEDED) - check_linker_flag(CXX "-Wl,--as-needed" HAVE_AS_NEEDED) - # Ensure that torch_cpu is ready before being linked by torch_xpu. add_dependencies(torch_xpu torch_cpu) - if(HAVE_NO_AS_NEEDED AND HAVE_AS_NEEDED) - target_link_libraries(torch_xpu PRIVATE - "-Wl,--no-as-needed,\"$\" -Wl,--as-needed") + if(MSVC) + target_link_libraries(torch_xpu PUBLIC torch_cpu_library) else() - target_link_libraries(torch_xpu PRIVATE "$") + include(CheckLinkerFlag) + + # Check whether the compiler supports '--no-as-needed' and '--as-needed' + check_linker_flag(CXX "-Wl,--no-as-needed" HAVE_NO_AS_NEEDED) + check_linker_flag(CXX "-Wl,--as-needed" HAVE_AS_NEEDED) + + if(HAVE_NO_AS_NEEDED AND HAVE_AS_NEEDED) + target_link_libraries(torch_xpu PRIVATE + "-Wl,--no-as-needed,\"$\" -Wl,--as-needed") + else() + target_link_libraries(torch_xpu PRIVATE "$") + endif() endif() endif() diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index c9b91dc31b880f..5fcf71016aea69 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -6,6 +6,7 @@ #include #include "caffe2/perfkernels/common.h" +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-prototypes") namespace caffe2 { /** @@ -125,7 +126,7 @@ static bool EmbeddingLookupGenericSlowIdx( const float* scale_bias, \ bool normalize_by_lengths, \ OutType* out) { \ - if (std::is_same::value) { \ + if constexpr (std::is_same::value) { \ CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \ } else { \ CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ @@ -231,3 +232,4 @@ EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true); #undef EMBEDDING_IDX_SPECIALIZATION } // namespace caffe2 +C10_DIAGNOSTIC_POP() diff --git a/caffe2/serialize/crc.cc b/caffe2/serialize/crc.cc index 7a7173e417fd3b..944d95c2ec3565 100644 --- a/caffe2/serialize/crc.cc +++ b/caffe2/serialize/crc.cc @@ -1,5 +1,4 @@ #include "miniz.h" -#include #include "caffe2/serialize/crc_alt.h" diff --git a/caffe2/utils/proto_wrap.h b/caffe2/utils/proto_wrap.h index 75e2c4be5c190c..4ce0ce5d3b8bfb 100644 --- a/caffe2/utils/proto_wrap.h +++ b/caffe2/utils/proto_wrap.h @@ -9,6 +9,29 @@ namespace caffe2 { // testing and valgrind cases to avoid protobuf appearing to "leak" memory). TORCH_API void ShutdownProtobufLibrary(); +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); } // namespace caffe2 +namespace ONNX_NAMESPACE { + +// ONNX wrapper functions for protobuf's GetEmptyStringAlreadyInited() function +// used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +} // namespace ONNX_NAMESPACE + +namespace torch { + +// Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() +// function used to avoid duplicated global variable in the case when protobuf +// is built with hidden visibility. +TORCH_API const ::std::string& GetEmptyStringAlreadyInited(); + +void ShutdownProtobufLibrary(); + +} // namespace torch #endif // CAFFE2_UTILS_PROTO_WRAP_H_ diff --git a/caffe2/utils/string_utils.cc b/caffe2/utils/string_utils.cc index 640cac70edbbad..ba763738f7bb04 100644 --- a/caffe2/utils/string_utils.cc +++ b/caffe2/utils/string_utils.cc @@ -3,6 +3,7 @@ #include #include #include +#include namespace caffe2 { diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index d221f5844f30d4..5e383d9715298d 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -322,6 +322,18 @@ if(INTERN_BUILD_ATEN_OPS) LIST(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} ${CXX_ZVECTOR_FLAGS}") endif(CXX_ZVECTOR_FOUND) + if(CXX_SVE_FOUND) + if(CXX_SVE256_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION") + list(APPEND CPU_CAPABILITY_NAMES "SVE256") + if("${CMAKE_C_COMPILER_ID}" MATCHES "Clang") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -O2 -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256") + else() + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=armv8.2-a+sve -DCPU_CAPABILITY_SVE -msve-vector-bits=256") + endif() + endif(CXX_SVE256_FOUND) + endif(CXX_SVE_FOUND) + list(LENGTH CPU_CAPABILITY_NAMES NUM_CPU_CAPABILITY_NAMES) math(EXPR NUM_CPU_CAPABILITY_NAMES "${NUM_CPU_CAPABILITY_NAMES}-1") diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 8abea841fcf61c..3e59b813d31381 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -544,6 +544,11 @@ if(USE_XNNPACK AND NOT USE_SYSTEM_XNNPACK) # Disable I8MM For CI since clang 9 does not support neon i8mm. set(XNNPACK_ENABLE_ARM_I8MM OFF CACHE BOOL "") + # Older MSVC versions don't support AVX512FP. TODO Minimum version support? + IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") + set(XNNPACK_ENABLE_AVX512FP16 OFF CACHE BOOL "") + ENDIF() + # Conditionally disable AVX512AMX, as it requires Clang 11 or later. Note that # XNNPACK does conditionally compile this based on GCC version. Once it also does # so based on Clang version, this logic can be removed. @@ -1103,10 +1108,6 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) - if(USE_CUDA) - caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) - endif() else() caffe2_update_option(USE_ROCM OFF) endif() @@ -1580,7 +1581,7 @@ if(USE_KINETO) message(STATUS "Using Kineto with Roctracer support") endif() - if(NOT USE_XPU) + if((NOT USE_XPU) OR WIN32) set(LIBKINETO_NOXPUPTI ON CACHE STRING "" FORCE) else() set(LIBKINETO_NOXPUPTI OFF CACHE STRING "") diff --git a/cmake/Modules/FindARM.cmake b/cmake/Modules/FindARM.cmake index ffd588cf2737df..340c2a64e0356e 100644 --- a/cmake/Modules/FindARM.cmake +++ b/cmake/Modules/FindARM.cmake @@ -5,7 +5,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") EXECUTE_PROCESS(COMMAND cat /proc/cpuinfo OUTPUT_VARIABLE CPUINFO) #neon instruction can be found on the majority part of modern ARM processor - STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE "${CPUINFO}") STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) IF (NEON_TRUE) set(NEON_FOUND true CACHE BOOL "NEON available on host") @@ -14,7 +14,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (NEON_TRUE) # on ARMv8, neon is inherit and instead listed as 'asimd' in /proc/cpuinfo - STRING(REGEX REPLACE "^.*(asimd).*$" "\\1" ASIMD_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(asimd).*$" "\\1" ASIMD_THERE "${CPUINFO}") STRING(COMPARE EQUAL "asimd" "${ASIMD_THERE}" ASIMD_TRUE) IF (ASIMD_TRUE) set(ASIMD_FOUND true CACHE BOOL "ASIMD/NEON available on host") @@ -22,8 +22,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") set(ASIMD_FOUND false CACHE BOOL "ASIMD/NEON available on host") ENDIF (ASIMD_TRUE) + #sve instruction can be found on the majority part of modern ARM processor + STRING(REGEX REPLACE "^.*(sve).*$" "\\1" SVE_THERE ${CPUINFO}) + STRING(COMPARE EQUAL "sve" "${SVE_THERE}" SVE_TRUE) + IF (SVE_TRUE) + set(SVE_FOUND true CACHE BOOL "SVE available on host") + ELSE (SVE_TRUE) + set(SVE_FOUND false CACHE BOOL "SVE available on host") + ENDIF (SVE_TRUE) + #Find the processor type (for now OMAP3 or OMAP4) - STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(OMAP3).*$" "\\1" OMAP3_THERE "${CPUINFO}") STRING(COMPARE EQUAL "OMAP3" "${OMAP3_THERE}" OMAP3_TRUE) IF (OMAP3_TRUE) set(CORTEXA8_FOUND true CACHE BOOL "OMAP3 available on host") @@ -32,7 +41,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux") ENDIF (OMAP3_TRUE) #Find the processor type (for now OMAP3 or OMAP4) - STRING(REGEX REPLACE "^.*(OMAP4).*$" "\\1" OMAP4_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(OMAP4).*$" "\\1" OMAP4_THERE "${CPUINFO}") STRING(COMPARE EQUAL "OMAP4" "${OMAP4_THERE}" OMAP4_TRUE) IF (OMAP4_TRUE) set(CORTEXA9_FOUND true CACHE BOOL "OMAP4 available on host") @@ -49,7 +58,7 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin") IF(NOT CPUINFO STREQUAL "") #neon instruction can be found on the majority part of modern ARM processor - STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE ${CPUINFO}) + STRING(REGEX REPLACE "^.*(neon).*$" "\\1" NEON_THERE "${CPUINFO}") STRING(COMPARE EQUAL "neon" "${NEON_THERE}" NEON_TRUE) IF (NEON_TRUE) set(NEON_FOUND true CACHE BOOL "NEON available on host") @@ -79,3 +88,72 @@ if(NOT CORTEXA9_FOUND) MESSAGE(STATUS "No OMAP4 processor on this machine.") endif(NOT CORTEXA9_FOUND) mark_as_advanced(NEON_FOUND) + +#SVE support is availale is only for Linux OS. +IF(CMAKE_SYSTEM_NAME MATCHES "Linux") + # Include necessary modules for checking C and C++ source compilations + INCLUDE(CheckCSourceCompiles) + INCLUDE(CheckCXXSourceCompiles) + + # Test code for SVE support + SET(SVE_CODE " + #include + int main() + { + svfloat64_t a; + a = svdup_n_f64(0); + return 0; + } + ") + + # Macro to check for SVE instruction support + MACRO(CHECK_SVE lang type flags) + # Save the current state of required flags + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + + # Set the flags necessary for compiling the test code with SVE support + SET(CMAKE_REQUIRED_FLAGS "${CMAKE_${lang}_FLAGS_INIT} ${flags}") + + # Check if the source code compiles with the given flags for the specified language (C or C++) + IF(lang STREQUAL "CXX") + CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type}) + ELSE() + CHECK_C_SOURCE_COMPILES("${SVE_CODE}" ${lang}_HAS_${type}) + ENDIF() + + # If the compilation test is successful, set appropriate variables indicating support + IF(${lang}_HAS_${type}) + set(${lang}_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "${flags}" CACHE STRING "${lang} ${type} flags") + ENDIF() + + # Restore the original state of required flags + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + # If the compilation test fails, indicate that the support is not found + IF(NOT ${lang}_${type}_FOUND) + SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support") + SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags") + ENDIF() + + # Mark the variables as advanced to hide them in the default CMake GUI + MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS) + ENDMACRO() + + # Check for SVE256 vector length + CHECK_SVE(CXX "SVE256" "-march=armv8-a+sve -msve-vector-bits=256") + + # If SVE256 support is not found, set CXX_SVE_FOUND to FALSE and notify the user + if(NOT CXX_SVE256_FOUND) + set(CXX_SVE_FOUND FALSE CACHE BOOL "SVE not available on host") + message(STATUS "No SVE processor on this machine.") + else() + # If SVE256 support is found, set CXX_SVE_FOUND to TRUE and notify the user + set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + message(STATUS "SVE support detected.") + endif() + + # Mark the SVE support variable as advanced + mark_as_advanced(CXX_SVE_FOUND) +ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux") diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index b69378b6b9fb96..234d361d7f5c27 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -43,7 +43,9 @@ IF(NOT MKLDNN_FOUND) endif() endif() if(LINUX) - set(ABI_NEUTRAL_FLAGS -fpreview-breaking-changes) + set(DNNL_CXX_FLAGS "-DCMAKE_CXX_FLAGS=-fpreview-breaking-changes") + else() + set(DNNL_CXX_FLAGS "") endif() ExternalProject_Add(xpu_mkldnn_proj SOURCE_DIR ${MKLDNN_ROOT} @@ -51,7 +53,7 @@ IF(NOT MKLDNN_FOUND) BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=${SYCL_CXX_DRIVER} - -DCMAKE_CXX_FLAGS=${ABI_NEUTRAL_FLAGS} + ${DNNL_CXX_FLAGS} -DDNNL_GPU_RUNTIME=SYCL -DDNNL_CPU_RUNTIME=THREADPOOL -DDNNL_BUILD_TESTS=OFF @@ -85,13 +87,18 @@ IF(NOT MKLDNN_FOUND) SET(ONEDNN_BUILD_GRAPH ON CACHE BOOL "" FORCE) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) + IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + MESSAGE("-- Will build oneDNN UKERNEL") + SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + ENDIF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + FIND_PACKAGE(BLAS) FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include) - FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) IF(NOT MKLDNN_INCLUDE_DIR) MESSAGE("MKLDNN_INCLUDE_DIR not found") EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT}) - FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) ENDIF(NOT MKLDNN_INCLUDE_DIR) IF(BUILD_ONEDNN_GRAPH) FIND_PATH(LLGA_INCLUDE_DIR dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include/oneapi/dnnl) diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index 5b249e5e444fcf..ec46a111eaed9f 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -49,12 +49,17 @@ find_file( ) # Find SYCL library fullname. -find_library( - SYCL_LIBRARY - NAMES sycl-preview - HINTS ${SYCL_LIBRARY_DIR} - NO_DEFAULT_PATH -) +# Don't use if(LINUX) here since this requires cmake>=3.25 and file is installed +# and used by other projects. +# See: https://cmake.org/cmake/help/v3.25/variable/LINUX.html +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + find_library( + SYCL_LIBRARY + NAMES sycl-preview + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH + ) +endif() # On Windows, currently there's no sycl.lib. Only sycl7.lib with version suffix, # where the current version of the SYCL runtime is 7. # Until oneAPI adds support to sycl.lib without the version suffix, diff --git a/docs/source/amp.rst b/docs/source/amp.rst index 7192a47fb24b60..8698742a9367d0 100644 --- a/docs/source/amp.rst +++ b/docs/source/amp.rst @@ -95,6 +95,11 @@ updates the parameters, so the scale factor does not interfere with the learning .. currentmodule:: torch.cuda.amp +.. autoclass:: GradScaler + :members: + +.. currentmodule:: torch.cpu.amp + .. autoclass:: GradScaler :members: @@ -365,7 +370,7 @@ in which unlisted ops run if they're downstream from autocasted ops. If an op is unlisted, we assume it's numerically stable in ``bfloat16``. If you believe an unlisted op is numerically unstable in ``bfloat16``, -please file an issue. +please file an issue. ``float16`` shares the lists of ``bfloat16``. CPU Ops that can autocast to ``bfloat16`` """"""""""""""""""""""""""""""""""""""""" @@ -375,19 +380,25 @@ CPU Ops that can autocast to ``bfloat16`` ``conv3d``, ``bmm``, ``mm``, +``linalg_vecdot``, ``baddbmm``, ``addmm``, ``addbmm``, ``linear``, ``matmul``, -``_convolution`` +``_convolution``, +``conv_tbc``, +``mkldnn_rnn_layer``, +``conv_transpose1d``, +``conv_transpose2d``, +``conv_transpose3d``, +``prelu``, +``scaled_dot_product_attention``, +``_native_multi_head_attention`` CPU Ops that can autocast to ``float32`` """""""""""""""""""""""""""""""""""""""" -``conv_transpose1d``, -``conv_transpose2d``, -``conv_transpose3d``, ``avg_pool3d``, ``binary_cross_entropy``, ``grid_sampler``, @@ -421,9 +432,22 @@ CPU Ops that can autocast to ``float32`` ``replication_pad2d``, ``replication_pad3d``, ``mse_loss``, +``cosine_embedding_loss``, +``nll_loss``, +``nll_loss2d``, +``hinge_embedding_loss``, +``poisson_nll_loss``, +``cross_entropy_loss``, +``l1_loss``, +``huber_loss``, +``margin_ranking_loss``, +``soft_margin_loss``, +``triplet_margin_loss``, +``multi_margin_loss``, ``ctc_loss``, ``kl_div``, ``multilabel_margin_loss``, +``binary_cross_entropy_with_logits``, ``fft_fft``, ``fft_ifft``, ``fft_fft2``, @@ -438,7 +462,6 @@ CPU Ops that can autocast to ``float32`` ``fft_irfftn``, ``fft_hfft``, ``fft_ihfft``, -``linalg_matrix_norm``, ``linalg_cond``, ``linalg_matrix_rank``, ``linalg_solve``, @@ -451,14 +474,10 @@ CPU Ops that can autocast to ``float32`` ``linalg_tensorinv``, ``linalg_tensorsolve``, ``fake_quantize_per_tensor_affine``, -``eig``, ``geqrf``, -``lstsq``, ``_lu_with_info``, ``qr``, -``solve``, ``svd``, -``symeig``, ``triangular_solve``, ``fractional_max_pool2d``, ``fractional_max_pool3d``, diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst index 328c3cecead6e1..5bdc4e81d352cb 100644 --- a/docs/source/cuda.rst +++ b/docs/source/cuda.rst @@ -122,6 +122,9 @@ Memory management change_current_allocator MemPool MemPoolContext + +.. autoclass:: torch.cuda.use_mem_pool + .. FIXME The following doesn't seem to exist. Is it supposed to? https://github.com/pytorch/pytorch/issues/27785 .. autofunction:: reset_max_memory_reserved diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index f4c73b9381e594..b0661f867c961c 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -876,7 +876,6 @@ If you are running single node training, it may be convenient to interactively b .. py:module:: torch.distributed.nn.api .. py:module:: torch.distributed.nn.jit .. py:module:: torch.distributed.nn.jit.templates -.. py:module:: torch.distributed.tensor .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks .. py:module:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst index de7525b32e2e2d..694212296e35b0 100644 --- a/docs/source/distributed.tensor.parallel.rst +++ b/docs/source/distributed.tensor.parallel.rst @@ -5,7 +5,7 @@ Tensor Parallelism - torch.distributed.tensor.parallel ====================================================== Tensor Parallelism(TP) is built on top of the PyTorch DistributedTensor -(`DTensor `__) +(`DTensor `__) and provides different parallelism styles: Colwise, Rowwise, and Sequence Parallelism. .. warning :: diff --git a/docs/source/distributed.tensor.rst b/docs/source/distributed.tensor.rst new file mode 100644 index 00000000000000..1df4f2e43d5d9a --- /dev/null +++ b/docs/source/distributed.tensor.rst @@ -0,0 +1,191 @@ +.. currentmodule:: torch.distributed.tensor + +torch.distributed.tensor +=========================== + +.. note:: + ``torch.distributed.tensor`` is currently in alpha state and under + development, we are committing backward compatibility for the most APIs listed + in the doc, but there might be API changes if necessary. + + +PyTorch DTensor (Distributed Tensor) +--------------------------------------- + +PyTorch DTensor offers simple and flexible tensor sharding primitives that transparently handles distributed +logic, including sharded storage, operator computation and collective communications across devices/hosts. +``DTensor`` could be used to build different paralleism solutions and support sharded state_dict representation +when working with multi-dimensional sharding. + +Please see examples from the PyTorch native parallelism solutions that are built on top of ``DTensor``: + +* `Tensor Parallel `__ +* `FSDP2 `__ + +.. automodule:: torch.distributed.tensor + +:class:`DTensor` follows the SPMD (single program, multiple data) programming model to empower users to +write distributed program as if it's a **single-device program with the same convergence property**. It +provides a uniform tensor sharding layout (DTensor Layout) through specifying the :class:`DeviceMesh` +and :class:`Placement`: + +- :class:`DeviceMesh` represents the device topology and the communicators of the cluster using + an n-dimensional array. + +- :class:`Placement` describes the sharding layout of the logical tensor on the :class:`DeviceMesh`. + DTensor supports three types of placements: :class:`Shard`, :class:`Replicate` and :class:`Partial`. + + +DTensor Class APIs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. currentmodule:: torch.distributed.tensor + +:class:`DTensor` is a ``torch.Tensor`` subclass. This means once a :class:`DTensor` is created, it could be +used in very similar way to ``torch.Tensor``, including running different types of PyTorch operators as if +running them in a single device, allowing proper distributed computation for PyTorch operators. + +In addition to existing ``torch.Tensor`` methods, it also offers a set of additional methods to interact with +``torch.Tensor``, ``redistribute`` the DTensor Layout to a new DTensor, get the full tensor content +on all devices, etc. + +.. autoclass:: DTensor + :members: + :member-order: bysource + + +DeviceMesh as the distributed communicator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. currentmodule:: torch.distributed.device_mesh + +:class:`DeviceMesh` was built from DTensor as the abstraction to describe cluster's device topology and represent +multi-dimensional communicators (on top of ``ProcessGroup``). To see the details of how to create/use a DeviceMesh, +please refer to the `DeviceMesh recipe `__. + + +DTensor Placement Types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. automodule:: torch.distributed.tensor.placement_types +.. currentmodule:: torch.distributed.tensor.placement_types + +DTensor supports the following types of :class:`Placement` on each :class:`DeviceMesh` dimension: + +.. autoclass:: Shard + :members: + :undoc-members: + +.. autoclass:: Replicate + :members: + :undoc-members: + +.. autoclass:: Partial + :members: + :undoc-members: + +.. autoclass:: Placement + :members: + :undoc-members: + + +Different ways to create a DTensor +--------------------------------------- + +.. currentmodule:: torch.distributed.tensor + +There're three ways to construct a :class:`DTensor`: + * :meth:`distribute_tensor` creates a :class:`DTensor` from a logical or "global" ``torch.Tensor`` on + each rank. This could be used to shard the leaf ``torch.Tensor`` s (i.e. model parameters/buffers + and inputs). + * :meth:`DTensor.from_local` creates a :class:`DTensor` from a local ``torch.Tensor`` on each rank, which can + be used to create :class:`DTensor` from a non-leaf ``torch.Tensor`` s (i.e. intermediate activation + tensors during forward/backward). + * DTensor provides dedicated tensor factory functions (e.g. :meth:`empty`, :meth:`ones`, :meth:`randn`, etc.) + to allow different :class:`DTensor` creations by directly specifying the :class:`DeviceMesh` and + :class:`Placement`. Compare to :meth:`distribute_tensor`, this could directly materializing the sharded memory + on device, instead of performing sharding after initializing the logical Tensor memory. + +Create DTensor from a logical torch.Tensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The SPMD (single program, multiple data) programming model in ``torch.distributed`` launches multiple processes +(i.e. via ``torchrun``) to execute the same program, this means that the model inside the program would be +initialized on different processes first (i.e. the model might be initialized on CPU, or meta device, or directly +on GPU if enough memory). + +``DTensor`` offers a :meth:`distribute_tensor` API that could shard the model weights or Tensors to ``DTensor`` s, +where it would create a DTensor from the "logical" Tensor on each process. This would empower the created +``DTensor`` s to comply with the single device semantic, which is critical for **numerical correctness**. + +.. autofunction:: distribute_tensor + +Along with :meth:`distribute_tensor`, DTensor also offers a :meth:`distribute_module` API to allow easier +sharding on the :class:`nn.Module` level + +.. autofunction:: distribute_module + + +DTensor Factory Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +DTensor also provides dedicated tensor factory functions to allow creating :class:`DTensor` directly +using torch.Tensor like factory function APIs (i.e. torch.ones, torch.empty, etc), by additionally +specifying the :class:`DeviceMesh` and :class:`Placement` for the :class:`DTensor` created: + +.. autofunction:: zeros + +.. autofunction:: ones + +.. autofunction:: empty + +.. autofunction:: full + +.. autofunction:: rand + +.. autofunction:: randn + + +Debugging +--------------------------------------- + +.. automodule:: torch.distributed.tensor.debug +.. currentmodule:: torch.distributed.tensor.debug + +Logging +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +When launching the program, you can turn on additional logging using the `TORCH_LOGS` environment variable from +`torch._logging `__ : + +* `TORCH_LOGS=+dtensor` will display `logging.DEBUG` messages and all levels above it. +* `TORCH_LOGS=dtensor` will display `logging.INFO` messages and above. +* `TORCH_LOGS=-dtensor` will display `logging.WARNING` messages and above. + +Debugging Tools +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To debug the program that applied DTensor, and understand more details about what collectives happened under the +hood, DTensor provides a :class:`CommDebugMode`: + +.. autoclass:: CommDebugMode + :members: + :undoc-members: + +To visualize the sharding of a DTensor that have less than 3 dimensions, DTensor provides :meth:`visualize_sharding`: + +.. autofunction:: visualize_sharding + + +Experimental Features +--------------------------------------- + +``DTensor`` also provides a set of experimental features. These features are either in prototyping stage, or the basic +functionality is done and but looking for user feedbacks. Please submit a issue to PyTorch if you have feedbacks to +these features. + +.. automodule:: torch.distributed.tensor.experimental +.. currentmodule:: torch.distributed.tensor.experimental + +.. autofunction:: local_map +.. autofunction:: register_sharding + + +.. modules that are missing docs, add the doc later when necessary +.. py:module:: torch.distributed.tensor.device_mesh diff --git a/docs/source/export.rst b/docs/source/export.rst index deb84548a19cdc..603594847f061c 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -291,6 +291,178 @@ torch.export>`), but seeing as the context manager does not affect the tensor computations in the model, we can go with the non-strict mode's result. +.. _Training Export: + +Export for Training and Inference +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In PyTorch 2.5, we introduced a new API called :func:`export_for_training`. +It's still going through hardening, so if you run into any issues, please file +them to Github with the "oncall: export" tag. + +In this API, we produce the most generic IR that contains all ATen operators +(including both functional and non-functional) which can be used to train in +eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization +and will soon be the default IR of torch.export.export. To read further about +the motivation behind this change, please refer to +https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + +When this API is combined with :func:`run_decompositions()`, you should be able to get inference IR with +any desired decomposition behavior. + +To show some examples: + +:: + + class ConvBatchnorm(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) + + mod = ConvBatchnorm() + inp = torch.randn(1, 1, 3, 3) + + ep_for_training = torch.export.export_for_training(mod, (inp,)) + print(ep_for_training) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None + batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + return (batch_norm,) + + Graph signature: + ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='batch_norm'), target=None) + ] + ) + Range constraints: {} + + +From the above output, you can see that :func:`export_for_training` produces pretty much the same ExportedProgram +as :func:`export` except for the operators in the graph. You can see that we captured batch_norm in the most general +form. This op is non-functional and will be lowered to different ops when running inference. + +You can also go from this IR to an inference IR via :func:`run_decompositions` with arbitrary customizations. + +:: + + # Lower to core aten inference IR, but keep conv2d + decomp_table = torch.export.core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + ep_for_inference = ep_for_training.run_decompositions(decomp_table) + + print(ep_for_inference) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + return (getitem_3, getitem_4, add, getitem) + + Graph signature: ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), + OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), + OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), + OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) + ] + ) + Range constraints: {} + +Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR +containing core aten operators except for `conv2d`. + +You can do even more customization by directly registering your chosen decomposition behaviors. + +You can do even more customizations by directly registering custom decomp behaviour + +:: + + # Lower to core aten inference IR, but customize conv2d + decomp_table = torch.export.core_aten_decompositions() + + def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + + decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function + ep_for_inference = ep_for_training.run_decompositions(decomp_table) + + print(ep_for_inference) + +.. code-block:: + + ExportedProgram: + class GraphModule(torch.nn.Module): + def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): + convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None + mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None + add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None + _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None + getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] + getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] + getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None + return (getitem_3, getitem_4, add, getitem) + + Graph signature: ExportGraphSignature( + input_specs=[ + InputSpec(kind=, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_weight'), target='bn.weight', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='p_bn_bias'), target='bn.bias', persistent=None), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_mean'), target='bn.running_mean', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_running_var'), target='bn.running_var', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='b_bn_num_batches_tracked'), target='bn.num_batches_tracked', persistent=True), + InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None) + ], + output_specs=[ + OutputSpec(kind=, arg=TensorArgument(name='getitem_3'), target='bn.running_mean'), + OutputSpec(kind=, arg=TensorArgument(name='getitem_4'), target='bn.running_var'), + OutputSpec(kind=, arg=TensorArgument(name='add'), target='bn.num_batches_tracked'), + OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None) + ] + ) + Range constraints: {} + + Expressing Dynamism ^^^^^^^^^^^^^^^^^^^ @@ -676,8 +848,8 @@ API Reference .. autofunction:: save .. autofunction:: load .. autofunction:: register_dataclass -.. autoclass:: torch.export.dynamic_shapes.DIM .. autofunction:: torch.export.dynamic_shapes.Dim +.. autofunction:: torch.export.exported_program.core_aten_decompositions .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection diff --git a/docs/source/index.rst b/docs/source/index.rst index dcaadcbb63edc1..773e64204293b7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -74,12 +74,13 @@ Features described in this documentation are classified by release status: torch.backends torch.export torch.distributed + torch.distributed.tensor torch.distributed.algorithms.join torch.distributed.elastic torch.distributed.fsdp + torch.distributed.tensor.parallel torch.distributed.optim torch.distributed.pipelining - torch.distributed.tensor.parallel torch.distributed.checkpoint torch.distributions torch.compiler diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 60dd67f643b890..8177b91a9c15c6 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -283,9 +283,11 @@ The following ops are currently supported: kron meshgrid narrow + nn.functional.unfold ravel select split + stack t transpose vsplit @@ -294,6 +296,7 @@ The following ops are currently supported: Tensor.expand_as Tensor.reshape Tensor.reshape_as + Tensor.unfold Tensor.view .. This module needs to be documented. Adding here in the meantime diff --git a/docs/source/mobile_optimizer.rst b/docs/source/mobile_optimizer.rst index 2624a5cb0bed9a..88eb87fb92f4d3 100644 --- a/docs/source/mobile_optimizer.rst +++ b/docs/source/mobile_optimizer.rst @@ -2,7 +2,11 @@ torch.utils.mobile_optimizer =================================== .. warning:: - This API is in beta and may change in the near future. + PyTorch Mobile is no longer actively supported. Please check out + `ExecuTorch `__, PyTorch's + all-new on-device inference library. You can also review + documentation on `XNNPACK `__ + and `Vulkan `__ delegates. Torch mobile supports ``torch.utils.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode. The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set, a preserved method list, and a backend. diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index 4b0e1a52269afc..548c6463d08c83 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -19,6 +19,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined is_available is_initialized memory_stats + get_device_capability set_device set_stream stream diff --git a/docs/source/notes/amp_examples.rst b/docs/source/notes/amp_examples.rst index f95f99b7ac2fa4..1ae63c4396cc6a 100644 --- a/docs/source/notes/amp_examples.rst +++ b/docs/source/notes/amp_examples.rst @@ -9,7 +9,7 @@ Ordinarily, "automatic mixed precision training" means training with :class:`torch.autocast` and :class:`torch.amp.GradScaler` together. Instances of :class:`torch.autocast` enable autocasting for chosen regions. -Autocasting automatically chooses the precision for GPU operations to improve performance +Autocasting automatically chooses the precision for operations to improve performance while maintaining accuracy. Instances of :class:`torch.amp.GradScaler` help perform the steps of diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 7d434bbbba64ce..74d0c89387fa7e 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -471,6 +471,13 @@ Available options: set the knob value to: [256:1,512:2,1024:4,>:8]. ``roundup_power2_divisions`` is only meaningful with ``backend:native``. With ``backend:cudaMallocAsync``, ``roundup_power2_divisions`` is ignored. +* ``max_non_split_rounding_mb`` will allow non-split blocks for better reuse, eg, + a 1024MB cached block can be re-used for a 512MB allocation request. In the default + case, we only allow up to 20MB of rounding of non-split blocks, so a 512MB block + can only be served with between 512-532 MB size block. If we set the value of this + option to 1024, it will alow 512-1536 MB size blocks to be used for a 512MB block + which increases reuse of larger blocks. This will also help in reducing the stalls + in avoiding expensive cudaMalloc calls. * ``garbage_collection_threshold`` helps actively reclaiming unused GPU memory to avoid triggering expensive sync-and-reclaim-all operation (release_cached_blocks), which can be unfavorable to latency-critical GPU applications (e.g., servers). @@ -527,6 +534,10 @@ Available options: allocation time of pinned memory. A good value for this option is 8 based on benchmarking results. + `pinned_use_background_threads` option is a boolean flag to enable background thread + for processing events. This avoids any slow path associated with querying/processing of + events in the fast allocation path. This feature is disabled by default. + .. note:: Some stats reported by the diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 5541d28bdcafa0..c05dc028a471c9 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,3 +398,4 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals +.. autoclass:: skip_data diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index ffaa8ef836b787..cb795cfb11f91a 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -13,7 +13,35 @@ The exported model can be consumed by any of the many `runtimes that support ONNX `_, including Microsoft's `ONNX Runtime `_. -**There are two flavors of ONNX exporter API that you can use, as listed below:** +**There are two flavors of ONNX exporter API that you can use, as listed below.** +Both can be called through function :func:`torch.onnx.export`. +Next example shows how to export a simple model. + +.. code-block:: python + + import torch + + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 128, 5) + + def forward(self, x): + return torch.relu(self.conv1(x)) + + input_tensor = torch.rand((1, 1, 128, 128), dtype=torch.float32) + + model = MyModel() + + torch.onnx.export( + model, # model to export + (input_tensor,), # inputs of the model, + "my_model.onnx", # filename of the ONNX model + input_names=["input"], # Rename inputs for the ONNX model + dynamo=True # True or False to select the exporter to use + ) + +Next sections introduces the two versions of the exporter. TorchDynamo-based ONNX Exporter ------------------------------- diff --git a/docs/source/onnx_dynamo.rst b/docs/source/onnx_dynamo.rst index be6b9d48a8d399..9865844f32e6b2 100644 --- a/docs/source/onnx_dynamo.rst +++ b/docs/source/onnx_dynamo.rst @@ -28,7 +28,6 @@ The exporter is designed to be modular and extensible. It is composed of the fol - **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model. - **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models. - **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics. - - **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file. - **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter. Dependencies @@ -45,6 +44,9 @@ They can be installed through `pip `_: pip install --upgrade onnx onnxscript +`onnxruntime `_ can then be used to execute the model +on a large variety of processors. + A simple example ---------------- @@ -75,9 +77,9 @@ See below a demonstration of exporter API in action with a simple Multilayer Per model = MLPModel() tensor_x = torch.rand((97, 8), dtype=torch.float32) - onnx_program = torch.onnx.dynamo_export(model, tensor_x) + onnx_program = torch.onnx.export(model, (tensor_x,), dynamo=True) -As the code above shows, all you need is to provide :func:`torch.onnx.dynamo_export` with an instance of the model and its input. +As the code above shows, all you need is to provide :func:`torch.onnx.export` with an instance of the model and its input. The exporter will then return an instance of :class:`torch.onnx.ONNXProgram` that contains the exported ONNX graph along with extra information. The in-memory model available through ``onnx_program.model_proto`` is an ``onnx.ModelProto`` object in compliance with the `ONNX IR spec `_. @@ -87,6 +89,17 @@ The ONNX model may then be serialized into a `Protobuf file `__ to help users debug and improve their model using a GUI, such as @@ -144,15 +162,9 @@ API Reference .. autoclass:: torch.onnx.ONNXProgram :members: -.. autoclass:: torch.onnx.ONNXProgramSerializer - :members: - .. autoclass:: torch.onnx.ONNXRuntimeOptions :members: -.. autoclass:: torch.onnx.InvalidExportOptionsError - :members: - .. autoclass:: torch.onnx.OnnxExporterError :members: diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst index 2009aea813d038..8c8032bd26b4da 100644 --- a/docs/source/onnx_torchscript.rst +++ b/docs/source/onnx_torchscript.rst @@ -702,8 +702,6 @@ Functions .. autofunction:: unregister_custom_op_symbolic .. autofunction:: select_model_mode_for_export .. autofunction:: is_in_onnx_export -.. autofunction:: enable_log -.. autofunction:: disable_log .. autofunction:: torch.onnx.verification.find_mismatch Classes diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index 44597d867b4970..a15e901bfd54c9 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -154,6 +154,7 @@ PT2 Export (pt2e) Numeric Debugger :template: classtemplate.rst generate_numeric_debug_handle + CUSTOM_KEY NUMERIC_DEBUG_HANDLE_KEY prepare_for_propagation_comparison extract_results_from_loggers diff --git a/docs/source/torch.compiler_faq.rst b/docs/source/torch.compiler_faq.rst index a5883ce015bef2..b7ff8bdd1aab26 100644 --- a/docs/source/torch.compiler_faq.rst +++ b/docs/source/torch.compiler_faq.rst @@ -136,17 +136,6 @@ Why is compilation slow? as long (as many iterations) as you were running when you ran into trouble, and the profiler will accumulate statistics over this duration. -.. code-block:: python - - from torch._dynamo.utils import CompileProfiler - - def my_model(): - ... - - with CompileProfiler() as prof: - profiler_model = torch.compile(my_model, backend=prof) - profiler_model() - print(prof.report()) Why are you recompiling in production? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/torch.compiler_profiling_torch_compile.rst b/docs/source/torch.compiler_profiling_torch_compile.rst index 33a6bca92faf98..0a465f1ca9b764 100644 --- a/docs/source/torch.compiler_profiling_torch_compile.rst +++ b/docs/source/torch.compiler_profiling_torch_compile.rst @@ -6,7 +6,7 @@ What to use torch.profiler for: torch.profiler is helpful for understanding the performance of your program at a kernel-level granularity - for example, it can show graph breaks and GPU utilization at the level of the program. The data provided by the profiler can often help users understand where to investigate further to understand model performance. -To understand kernel-level performance, other toosl exist. NVIDIA's ncu tool can be used, or :ref:`inductor's profiling tools `. +To understand kernel-level performance, other tools exist. NVIDIA's ncu tool can be used, or :ref:`inductor's profiling tools `. See also the `general pytorch profiler guide `_. @@ -66,7 +66,7 @@ Alternatively, turn on *all* flows with the “Flow events” dropdown at the to Working around CUDA Graph profiling issues ------------------------------------------ -When CUDA graphs are enabled, some cuda configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: +When CUDA graphs are enabled, some CUDA configurations (driver version under 525.85.12 or CUDA < 12) can encounter issues between the profiling tools and CUDA graphs. To fix these issues, add an empty profiling context at the top of your program: .. code-block:: python diff --git a/docs/source/torch.compiler_troubleshooting.rst b/docs/source/torch.compiler_troubleshooting.rst index 7158149c09e190..05560789f07b45 100644 --- a/docs/source/torch.compiler_troubleshooting.rst +++ b/docs/source/torch.compiler_troubleshooting.rst @@ -663,13 +663,6 @@ recompile that function (or part) up to hitting the cache limit, you will first need to determine which guard is failing and what part of your program is triggering it. -The `compile profiler `__ automates the -process of setting TorchDynamo’s cache limit to 1 and running your -program under an observation-only 'compiler' that records the causes of -any guard failures. You should be sure to run your program for at least -as long (as many iterations) as you were running when you ran into -trouble, and the profiler will accumulate statistics over this duration. - If your program exhibits a bounded amount of dynamism, you may be able to tune the TorchDynamo cache limit to allow for each variation to be compiled and cached, but if the cache limit is too high you may find the @@ -685,18 +678,6 @@ support rank-dynamism. In the meantime, setting a specific cache limit can be used in coordination with bucketing techniques to achieve an acceptable number of recompilations for some dynamic models. -.. code-block:: python - - from torch._dynamo.utils import CompileProfiler - - def my_model(): - ... - - with CompileProfiler() as prof: - profiler_model = torch.compile(my_model, backend=prof) - profiler_model() - print(prof.report()) - Accuracy Debugging ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/xpu.rst b/docs/source/xpu.rst index d4085cf4e6267c..a83bea4d1b3f86 100644 --- a/docs/source/xpu.rst +++ b/docs/source/xpu.rst @@ -13,7 +13,6 @@ torch.xpu device device_count device_of - empty_cache get_device_capability get_device_name get_device_properties @@ -51,7 +50,25 @@ Streams and events Stream +Memory management +----------------- +.. autosummary:: + :toctree: generated + :nosignatures: + + empty_cache + max_memory_allocated + max_memory_reserved + memory_allocated + memory_reserved + memory_stats + memory_stats_as_nested_dict + reset_accumulated_memory_stats + reset_peak_memory_stats + + .. This module needs to be documented. Adding here in the meantime .. for tracking purposes +.. py:module:: torch.xpu.memory .. py:module:: torch.xpu.random -.. py:module:: torch.xpu.streams \ No newline at end of file +.. py:module:: torch.xpu.streams diff --git a/pytest.ini b/pytest.ini index d3b9f3a9229843..e2ab2ebd0cc164 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,8 +4,6 @@ addopts = -rEfX # Make tracebacks shorter --tb=native - # Color the output - --color=yes # capture only Python print and C++ py::print, but not C output (low-level Python errors) --capture=sys # don't suppress warnings, but don't shove them all to the end either diff --git a/requirements.txt b/requirements.txt index 8ebae971e488f8..477332375872f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Python dependencies required for development astunparse -expecttest!=0.2.0 +expecttest>=0.2.1 hypothesis numpy psutil diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat index fbdade94353554..60bfebad08c016 100644 --- a/scripts/build_windows.bat +++ b/scripts/build_windows.bat @@ -22,10 +22,6 @@ if NOT DEFINED BUILD_SHARED_LIBS ( ) ) -IF NOT DEFINED BUILDING_WITH_TORCH_LIBS ( - set BUILDING_WITH_TORCH_LIBS=OFF -) - if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( set CAFFE2_STATIC_LINK_CUDA=OFF ) diff --git a/scripts/compile_tests/download_reports.py b/scripts/compile_tests/download_reports.py index c428ddca8a04ba..7ad6521032c9ab 100644 --- a/scripts/compile_tests/download_reports.py +++ b/scripts/compile_tests/download_reports.py @@ -8,10 +8,10 @@ CONFIGS = { - "dynamo38": { - "linux-focal-py3.8-clang10 / test (dynamo, 1, 3, linux.2xlarge)", - "linux-focal-py3.8-clang10 / test (dynamo, 2, 3, linux.2xlarge)", - "linux-focal-py3.8-clang10 / test (dynamo, 3, 3, linux.2xlarge)", + "dynamo39": { + "linux-focal-py3.9-clang10 / test (dynamo, 1, 3, linux.2xlarge)", + "linux-focal-py3.9-clang10 / test (dynamo, 2, 3, linux.2xlarge)", + "linux-focal-py3.9-clang10 / test (dynamo, 3, 3, linux.2xlarge)", }, "dynamo311": { "linux-focal-py3.11-clang10 / test (dynamo, 1, 3, linux.2xlarge)", @@ -26,7 +26,7 @@ } -def download_reports(commit_sha, configs=("dynamo38", "dynamo311", "eager311")): +def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")): log_dir = "tmp_test_reports_" + commit_sha def subdir_path(config): diff --git a/scripts/compile_tests/update_failures.py b/scripts/compile_tests/update_failures.py index b5511038c29edb..a56e30e99870c6 100755 --- a/scripts/compile_tests/update_failures.py +++ b/scripts/compile_tests/update_failures.py @@ -221,5 +221,5 @@ def read_test_results(directory): args = parser.parse_args() assert Path(args.filename).exists(), args.filename assert Path(args.test_dir).exists(), args.test_dir - dynamo38, dynamo311 = download_reports(args.commit, ("dynamo38", "dynamo311")) - update(args.filename, args.test_dir, dynamo38, dynamo311, args.also_remove_skips) + dynamo39, dynamo311 = download_reports(args.commit, ("dynamo39", "dynamo311")) + update(args.filename, args.test_dir, dynamo39, dynamo311, args.also_remove_skips) diff --git a/scripts/release/apply-release-changes.sh b/scripts/release/apply-release-changes.sh index c1f37a402928bd..49d3d0b5192f76 100755 --- a/scripts/release/apply-release-changes.sh +++ b/scripts/release/apply-release-changes.sh @@ -17,19 +17,25 @@ GIT_TOP_DIR=$(git rev-parse --show-toplevel) RELEASE_VERSION=${RELEASE_VERSION:-$(cut -d'.' -f1-2 "${GIT_TOP_DIR}/version.txt")} DRY_RUN=${DRY_RUN:-enabled} -# Change all GitHub Actions to reference the test-infra release branch -# as opposed to main. echo "Applying to workflows" for i in .github/workflows/*.yml; do sed -i -e s#@main#@"release/${RELEASE_VERSION}"# $i; done -# Change all checkout step in templates to not add ref to checkout echo "Applying to templates" for i in .github/templates/*.yml.j2; do sed -i 's#common.checkout(\(.*\))#common.checkout(\1, checkout_pr_head=False)#' $i; + sed -i -e s#main#"release/${RELEASE_VERSION}"# $i; done +echo "Applying to changes to linux binary builds" +for i in ".github/workflows/_binary-build-linux.yml" ".github/workflows/_binary-test-linux.yml"; do + sed -i "/github.event_name == 'pull_request'/d" $i; + sed -i -e s#main#"release/${RELEASE_VERSION}"# $i; +done + +sed -i -e "/generate_ci_workflows.py/i \\\t\t\t\texport RELEASE_VERSION_TAG=${RELEASE_VERSION}" .github/workflows/lint.yml + # Triton wheel echo "Triton Changes" sed -i -e s#-\ main#"-\ release\/${RELEASE_VERSION}"# .github/workflows/build-triton-wheel.yml diff --git a/scripts/release/tag-docker-images.sh b/scripts/release/tag-docker-images.sh index ab366ecc0e3e2b..f2299d6c463ee2 100644 --- a/scripts/release/tag-docker-images.sh +++ b/scripts/release/tag-docker-images.sh @@ -12,7 +12,7 @@ # git submodule update --init --recursive # # Usage (run from root of project): -# DRY_RUN=disabled ./scripts/release/tag_docker_images.sh +# DRY_RUN=disabled ./scripts/release/tag-docker-images.sh # set -eou pipefail diff --git a/setup.py b/setup.py index ad48f4b0108633..e9f5d2a579432c 100644 --- a/setup.py +++ b/setup.py @@ -1107,6 +1107,12 @@ def make_relative_rpath_args(path): "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs", ], } + + if cmake_cache_vars["USE_DISTRIBUTED"]: + # Only enable fr_trace command if distributed is enabled + entry_points["console_scripts"].append( + "torchfrtrace = tools.flight_recorder.fr_trace:main", + ) return extensions, cmdclass, packages, entry_points, extra_install_requires @@ -1238,6 +1244,7 @@ def main(): "include/ATen/cpu/vec/vec256/zarch/*.h", "include/ATen/cpu/vec/vec512/*.h", "include/ATen/cpu/vec/*.h", + "include/ATen/cpu/vec/sve/*.h", "include/ATen/core/*.h", "include/ATen/cuda/*.cuh", "include/ATen/cuda/*.h", @@ -1322,6 +1329,7 @@ def main(): "include/torch/csrc/distributed/autograd/rpc_messages/*.h", "include/torch/csrc/dynamo/*.h", "include/torch/csrc/inductor/*.h", + "include/torch/csrc/inductor/aoti_package/*.h", "include/torch/csrc/inductor/aoti_runner/*.h", "include/torch/csrc/inductor/aoti_runtime/*.h", "include/torch/csrc/inductor/aoti_torch/*.h", @@ -1505,7 +1513,7 @@ def main(): f"Programming Language :: Python :: 3.{i}" for i in range(python_min_version[1], version_range_max) ], - license="BSD-3", + license="BSD-3-Clause", keywords="pytorch, machine learning", ) if EMIT_BUILD_WARNING: diff --git a/test/HowToWriteTestsUsingFileCheck.md b/test/HowToWriteTestsUsingFileCheck.md index 0795c23002a162..3a9b28d3574678 100644 --- a/test/HowToWriteTestsUsingFileCheck.md +++ b/test/HowToWriteTestsUsingFileCheck.md @@ -93,7 +93,7 @@ annotations from the example above one would write: * `CHECK-COUNT-EXACTLY-: ` Scans the input and succeeds when a line containing exactly `NUM` entries of `PATTERN` is found. -* `CHECK-DAG: pattern` +* `CHECK-DAG: ` Works similar to the usual `CHECK` pragma, but also matches if there exists a way to reorder the CHECK-DAG pragmas to satisfy all patterns. For example the following pattern: @@ -110,3 +110,18 @@ annotations from the example above one would write: bar end ``` +* `CHECK-SOURCE-HIGHLIGHTED: ` + Check for highlighted source ranges. This is useful when writing tests regarding generated error messages that require source code highlighting. + For example the following pattern: + ``` + # CHECK-SOURCE-HIGHLIGHTED: raise Exception("raised exception + ``` + would match the following input: + ``` + def method_that_raises() -> torch.Tensor: + raise Exception("raised exception") # noqa: TRY002 + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE + builtins.Exception: raised exception + ``` +* `CHECK-REGEX: ` + Scans the input until `PATTERN` is matched, accepts RE syntax for std::regex. diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 05b1b8223cc038..9107da9a37cfe0 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -33,7 +33,8 @@ "torch.nn.quantizable": "torch.ao.nn.quantizable", "torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules", "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation", - "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn" + "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn", + "torch.distributed.tensor.device_mesh": "torch.distributed.device_mesh" }, "torch.backends": [ "contextmanager" diff --git a/test/cpp/aoti_inference/aoti_custom_class.cpp b/test/cpp/aoti_inference/aoti_custom_class.cpp index 5f3ea58620dd22..d3960d9f5074eb 100644 --- a/test/cpp/aoti_inference/aoti_custom_class.cpp +++ b/test/cpp/aoti_inference/aoti_custom_class.cpp @@ -29,12 +29,14 @@ MyAOTIClass::MyAOTIClass( const std::string& model_path, const std::string& device) : lib_path_(model_path), device_(device) { - if (device_ == "cuda") { - runner_ = std::make_unique( - model_path.c_str()); - } else if (device_ == "cpu") { + if (device_ == "cpu") { runner_ = std::make_unique( model_path.c_str()); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device_ == "cuda") { + runner_ = std::make_unique( + model_path.c_str()); +#endif } else { throw std::runtime_error("invalid device: " + device); } diff --git a/test/cpp/aoti_inference/compile_model.py b/test/cpp/aoti_inference/compile_model.py index 4542a9c8f9f80a..0668f9b4b91250 100644 --- a/test/cpp/aoti_inference/compile_model.py +++ b/test/cpp/aoti_inference/compile_model.py @@ -86,7 +86,7 @@ def compile_model(device, data): def main(): data = {} - for device in ["cuda", "cpu"]: + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: compile_model(device, data) torch.jit.script(TensorSerializer(data)).save("script_data.pt") diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index a74109f1a9b9a9..8a97959172c97c 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -35,12 +35,14 @@ void test_aoti(const std::string& device, bool use_runtime_constant_folding) { data_loader.attr(outputs_attr.c_str()).toTensorList().vec(); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -111,12 +113,14 @@ void test_aoti_constants_update( real_map.emplace("L__self___w_add", new at::Tensor(add_tensors)); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -197,12 +201,14 @@ void test_aoti_double_buffering( real_map.emplace("L__self___w_add", new at::Tensor(add_tensors)); std::unique_ptr runner; - if (device == "cuda") { - runner = std::make_unique( - model_so_path.c_str()); - } else if (device == "cpu") { + if (device == "cpu") { runner = std::make_unique( - model_so_path.c_str()); + model_so_path); +#if defined(USE_CUDA) || defined(USE_ROCM) + } else if (device == "cuda") { + runner = std::make_unique( + model_so_path); +#endif } else { testing::AssertionFailure() << "unsupported device: " << device; } @@ -241,6 +247,7 @@ void test_aoti_double_buffering( ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); } +#if defined(USE_CUDA) || defined(USE_ROCM) void test_aoti_double_buffering_with_tensor_constants() { torch::NoGradGuard no_grad; @@ -279,6 +286,7 @@ void test_aoti_double_buffering_with_tensor_constants() { actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); } +#endif } // namespace diff --git a/test/cpp/aoti_inference/test.py b/test/cpp/aoti_inference/test.py index ea3f6f042d3c17..f5e730158ccc00 100644 --- a/test/cpp/aoti_inference/test.py +++ b/test/cpp/aoti_inference/test.py @@ -35,7 +35,7 @@ def forward(self, x, y): # Basice AOTI model test generation. def generate_basic_tests(): - for device in ["cpu", "cuda"]: + for device in ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]: for use_runtime_constant_folding in [True, False]: if device == "cpu" and use_runtime_constant_folding: # We do not test runtime const folding for cpu mode. @@ -74,6 +74,9 @@ def generate_basic_tests(): # AOTI model which will create additional tensors during autograd. def generate_test_with_additional_tensors(): + if not torch.cuda.is_available(): + return + model = NetWithTensorConstants() x = torch.randn((30, 1), device="cuda") y = torch.randn((30, 1), device="cuda") diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index c64adb4ad0521e..0874852517e33b 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -10,14 +10,12 @@ function(c10d_add_test test_src) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) target_link_libraries(${test_name} ${ARGN}) - target_link_libraries(${test_name} fmt::fmt-header-only) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) endfunction() -c10d_add_test(LoggingTest.cpp torch_cpu gtest_main) c10d_add_test(BackoffTest.cpp torch_cpu gtest_main) c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) diff --git a/test/cpp/c10d/LoggingTest.cpp b/test/cpp/c10d/LoggingTest.cpp deleted file mode 100644 index 56e4ad5d3f1df4..00000000000000 --- a/test/cpp/c10d/LoggingTest.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include - -#include -#include - -#include - -TEST(LockGuard, basic) { - std::timed_mutex mutex; - - { - C10D_LOCK_GUARD(lock, mutex); - - // already locked - ASSERT_FALSE(mutex.try_lock()); - } - - ASSERT_TRUE(mutex.try_lock()); - mutex.unlock(); -} - -TEST(LockGuard, logging) { - std::timed_mutex mutex; - - mutex.lock(); - - auto loggingThread = std::async(std::launch::async, [&]() { - std::unique_lock name{mutex, std::defer_lock}; - ::c10d::detail::lockWithLogging( - name, std::chrono::milliseconds(10), "my lock", __FILE__, __LINE__); - }); - - auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(10); - while (true) { - ASSERT_LT(std::chrono::system_clock::now(), deadline); - - testing::internal::CaptureStderr(); - std::this_thread::sleep_for(std::chrono::milliseconds(20)); - std::string output = testing::internal::GetCapturedStderr(); - - if (output.find("my lock: waiting for lock for 10ms") != - std::string::npos) { - break; - } - } - - mutex.unlock(); - - loggingThread.get(); -} diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index 54a929ab982514..d416847f7911a5 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -180,7 +180,7 @@ class ProcessGroupNCCLNoHeartbeatCaught : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), hasMonitorThreadCaughtError_(false) {} - std::timed_mutex& getWatchdogMutex() { + std::mutex& getWatchdogMutex() { return workMetaListMutex_; } @@ -413,7 +413,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { work = pg.allreduce(tensors_); { // Now run all reduce with errors. - std::lock_guard lock(pg.getWatchdogMutex()); + std::lock_guard lock(pg.getWatchdogMutex()); LOG(INFO) << "Lock watchdog thread."; // Wait long enough before monitor thread throws exceptions. std::this_thread::sleep_for( diff --git a/test/cpp/c10d/TCPStoreTest.cpp b/test/cpp/c10d/TCPStoreTest.cpp index c0a7b2b64d601e..7351984f36c997 100644 --- a/test/cpp/c10d/TCPStoreTest.cpp +++ b/test/cpp/c10d/TCPStoreTest.cpp @@ -178,11 +178,12 @@ TEST(TCPStoreTest, testCleanShutdown) { auto serverTCPStore = std::make_unique( "127.0.0.1", - 0, - numWorkers, - true, - std::chrono::seconds(defaultTimeout), - /* wait */ false); + c10d::TCPStoreOptions{ + /* port */ 0, + /* isServer */ true, + numWorkers, + /* waitWorkers */ false, + /* timeout */ std::chrono::seconds(defaultTimeout)}); c10d::test::set(*serverTCPStore, "key", "val"); auto clientTCPStore = c10::make_intrusive( diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp index 0b435ddb928b44..f857aecc657cc8 100644 --- a/test/cpp_extensions/open_registration_extension.cpp +++ b/test/cpp_extensions/open_registration_extension.cpp @@ -586,7 +586,7 @@ struct FooHooksArgs : public at::PrivateUse1HooksArgs {}; struct FooHooksInterface : public at::PrivateUse1HooksInterface { FooHooksInterface(FooHooksArgs) {} ~FooHooksInterface() override = default; - const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) override { + const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override { static auto device_gen = make_generator_privateuse1(device_index); return device_gen; } diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md index f304a96ecb1a87..07f1f98d915a76 100644 --- a/test/cpp_extensions/open_registration_extension/README.md +++ b/test/cpp_extensions/open_registration_extension/README.md @@ -1,4 +1,4 @@ -This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core. +This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core. ## How to use Install as standalone with `python setup.py develop` (or install) from this folder. @@ -8,6 +8,23 @@ You can run test via `python test/test_openreg.py`. For simplicity anything that can be implemented from python is done so. A real implementation will most likely want to call these different APIs from c++ directly. -The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn. +The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing. -Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate. +The codebase is split as follows: +- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations to torch, imports `.C` to load our c++ extension that registers more ops, allocator and hooks and finally renames the PrivateUse1 backend and register our python-side module. +- `pytorch_openreg/_aten_impl.py` does two main things. Use the `_register_same_name()` function to register hooks from c++ (like getDevice, getStream, etc) and send them to our device daemon. Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation +- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice looking Tensors on the device side to be able to use aten ops to run code there), `run_op` that is the logic running on the device side to perform compute (for simplicity of coverage, we are re-building full blown Tensors here and calling aten ops on them). It also contains the Daemon responsible for the device worker process and sending data back and forth. +- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor. + +## Next steps + +Currently, the autograd test is disabled because it's missing the getStream implementation. +The main next step would be to: +- Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimick which information is held on the user-process side and when we're actually communicating with the device. In particular current device or stream should be user-process informations. +- Add Stream/Event system. Most likely by having multiple requests queue that go to the device from the driver. +- Add RNG Generator. +- Add Pinned memory and HostAllocator. + +Longer term: +- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. +- Build this module in the CI environment and enable Device-generic tests on this device. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py index fa2a286f22c91b..fa231cff5b9d36 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py @@ -1,26 +1,13 @@ import torch - -# Global properties of our device -NUM_DEVICES = 7 - # Create our python implementation dict so that the C++ module # can access it during its initialization -_IMPL_REGISTRY = {} - -# Load the C++ Module -import pytorch_openreg._C # noqa: F401 - - -# Define all the implementations in the registry -def register(fn): - _IMPL_REGISTRY[fn.__name__[1:]] = fn - return fn +# Also register aten impls +from ._aten_impl import _IMPL_REGISTRY as _IMPL_REGISTRY # noqa: F401 -@register -def _deviceCount(): - return NUM_DEVICES +# Load the C++ Module +import pytorch_openreg._C # noqa: F401 # usort: skip # Module used for our backend @@ -31,15 +18,3 @@ class _OpenRegMod: # Set all the appropriate state on PyTorch torch.utils.rename_privateuse1_backend("openreg") torch._register_device_module("openreg", _OpenRegMod()) - -_openreg_lib = torch.library.Library("_", "IMPL") # ignore TOR901 - - -def _openreg_kernel_fallback(op, *args, **kwargs): - print("Calling ", op) - assert op is torch.ops.aten.empty.memory_format - # FIXME: this returns a cpu Tensor which is NOT ok. - return torch.empty(args[0]) - - -_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py new file mode 100644 index 00000000000000..7103655185ba89 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -0,0 +1,154 @@ +import logging + +import torch +from torch.utils._pytree import tree_any + + +log = logging.getLogger(__name__) + +from ._device_daemon import driver +from ._meta_parser import prepare_for_sending, to_device_no_copy + + +_IMPL_REGISTRY = {} + + +# Define all the implementations in the registry +def _register_same_name(name, with_log=False): + def _(*args, **kwargs): + if with_log: + log.info("Calling hook %s", name) + return driver.exec(name, *args, **kwargs) + + _IMPL_REGISTRY[name] = _ + + +_register_same_name("deviceCount") +_register_same_name("getDevice") +_register_same_name("uncheckedSetDevice") +_register_same_name("exchangeDevice") +_register_same_name("malloc", True) +_register_same_name("free", True) + +_openreg_lib = torch.library.Library("_", "IMPL") + + +def _openreg_kernel_fallback(op, *args, **kwargs): + log.info("Calling kernel %s", op) + + # Special ops needed to avoid infinite recursion + if op is torch.ops.aten._copy_from.default: + from_, to_ = args + if from_.device.type == to_.device.type: + assert from_.device.type == "openreg" + op = torch.ops.aten.copy_.default + args = to_, from_ + # handled below as a regular copy + elif from_.device.type == "openreg": + args, _ = prepare_for_sending((from_,), {}) + host_mem = driver.exec("send_data", *args) + return to_.copy_(host_mem) + elif to_.device.type == "openreg": + args, _ = prepare_for_sending((to_,), {}) + driver.exec("recv_data", from_, *args) + return to_ + else: + raise RuntimeError("Should not happen") + elif op is torch.ops.aten.set_.source_Tensor: + return torch.ops.aten.set_.source_Storage_storage_offset( + args[0], + args[1].untyped_storage(), + args[1].storage_offset(), + args[1].size(), + args[1].stride(), + ) + elif op is torch.ops.aten._local_scalar_dense.default: + args, _ = prepare_for_sending(args, {}) + host_mem = driver.exec("send_data", *args) + return host_mem.item() + + op_name = None + post_process = None + if "out" in op._overloadname: + # Note that all structured native op will call here + if isinstance(kwargs["out"], tuple): + raise RuntimeError(f"out= variant {op} with tuple out= not supported") + if kwargs["out"].nelement() == 0: + # Out variant that needs a resize, convert to an out of place + # and handle generically below + orig_out = kwargs["out"] + del kwargs["out"] + if op._overloadname != "out": + raise RuntimeError( + "Cannot retranslate non-default out= variant form 0 size" + ) + op = op.overloadpacket.default + + def _post_process(): + nonlocal real_res + orig_out.set_(real_res) + real_res = orig_out + + post_process = _post_process + + else: + # No metadata update to do, just run the op on the device + op_name = op.overloadpacket._qualified_op_name + real_res = kwargs["out"] + elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)): + # No Tensor argument means factory function + # They should decompose and be handled in our c++ side directly + raise RuntimeError(f"{op} not handled yet.") + elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: + # Only handle inplace ops returning their first arg + assert len(args) >= 1, f"Inplace {op} needs at least one arg" + assert ( + len(op._schema.returns) == 1 + ), f"NYI Inplace {op} with more than one return" + op_name = op.overloadpacket._qualified_op_name + real_res = args[0] + elif any(r.alias_info is not None for r in op._schema.returns): + # View ops + if op is torch.ops.aten.view.default: + return torch.ops.aten._unsafe_view(*args, **kwargs) + raise RuntimeError(f"{op} view op is not handled yet") + + if op_name is None: + # 1. Compute updated metadata + if torch.Tag.dynamic_output_shape not in op.tags: + # Usual case: run the meta op to see the output metadata + meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs) + meta_res = op(*meta_args, **meta_kwargs) + + # 2. Allocate the output + real_res, _ = to_device_no_copy("openreg", meta_res, {}) + else: + # Slow version for data-dependent functions: + # Run the op on the device just to get the output shape + args_, kwargs_ = prepare_for_sending(args, kwargs) + shape = driver.exec( + "get_op_output_shape", + op.overloadpacket._qualified_op_name, + args_, + kwargs_, + ) + + # 2. Allocate the output + real_res = args[0].new(shape) + + # 3. Move to out variant + kwargs["out"] = real_res + # Let overload resolution find the out= overload + op_name = op.overloadpacket._qualified_op_name + + # 4. Run the compute and populate the output on the device + args, kwargs = prepare_for_sending(args, kwargs) + driver.exec("run_op", op_name, args, kwargs) + + if post_process is not None: + post_process() + + return real_res + + +_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py new file mode 100644 index 00000000000000..3b6ee8638939bb --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -0,0 +1,198 @@ +import logging + +import torch + +from ._meta_parser import ( + OpenRegTensorData, + receive_after_sending, + safe_str, + validate_send_queue_args, +) + + +log = logging.getLogger(__name__) +mp_context = torch.multiprocessing.get_context("spawn") + + +# Our allocator +class Allocator: + def __init__(self): + self.allocated = {} + + def malloc(self, size): + new_data = torch.empty(size, dtype=torch.uint8) + ptr = new_data.data_ptr() + self.allocated[ptr] = new_data + return ptr + + def free(self, ptr): + if ptr not in self.allocated: + return False + else: + del self.allocated[ptr] + return True + + def tensor_from_meta(self, meta): + # Usual case, we're receiving a known Tensor + found_base = self.allocated.get(meta.data_ptr, None) + + # Might be a rewrap of another storage at a different offset + # Slow path to try and find the corresponding storage + if found_base is None: + for tag, t in self.allocated.items(): + # t is always a 1D uint8 storage! + if meta.data_ptr > tag and meta.data_ptr < tag + t.nelement(): + # Blame @ngimel for this + slice_size = t.nelement() - (meta.data_ptr - tag) + found_base = torch.tensor((), dtype=torch.uint8).set_( + t.untyped_storage()[meta.data_ptr - tag :], + size=(slice_size,), + stride=(1,), + storage_offset=0, + ) + + # This pointer is not allocated here, segfault ! + if found_base is None: + log.info("Currently allocated blocks:\n %s", safe_str(self.allocated)) + log.info("Trying to access %s", meta) + raise RuntimeError("SEGFAULT!") + + # Raw 1d uint8 data + raw = found_base + # Slice the right storage part + raw_slice = raw.narrow(0, 0, meta.nelem_in_bytes) + # Reinterpret cast in the right dtype + as_dtype = raw_slice.view(dtype=meta.dtype) + # View to the right shape/stride/offset + view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset) + return view + + +def register(registry): + def func(fn): + registry[fn.__name__] = fn + return fn + + return func + + +class Driver: + def __init__(self): + super().__init__() + self.is_initialized = False + + def _lazy_init(self): + if self.is_initialized: + return + + # State of our driver + self.curr_device_idx = 0 + self.curr_stream = 0 + # Constant properties of our device + self.num_devices = 7 + + self.req_queue = mp_context.Queue() + self.ans_queue = mp_context.Queue() + + self.runner = mp_context.Process( + target=_Executor().run_forever, + args=(self.req_queue, self.ans_queue), + daemon=True, + ) + self.runner.start() + self.is_initialized = True + + def exec(self, cmd, *args): + self._lazy_init() + log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) + + if cmd in Driver.registry: + res = Driver.registry[cmd](self, *args) + else: + validate_send_queue_args(cmd, args) + self.req_queue.put((cmd,) + args) + res = self.ans_queue.get() + + log.info("Main process result for %s received: %s", cmd, safe_str(res)) + if res == "ERROR": + raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") + else: + return res + + registry = {} + + @register(registry) + def deviceCount(self, *args): + assert len(args) == 0 + return self.num_devices + + @register(registry) + def getDevice(self): + return self.curr_device_idx + + @register(registry) + def uncheckedSetDevice(self, *args): + assert len(args) == 1 + self.curr_device_idx = int(args[0]) + + @register(registry) + def exchangeDevice(self, *args): + assert len(args) == 1 + res = self.curr_device_idx + self.curr_device_idx = int(args[0]) + return res + + +class _Executor: + def __init__(self): + self.allocator = Allocator() + + def run_forever(self, req_queue, ans_queue): + # Serve all requests + while True: + cmd, *args = req_queue.get() + log.info("Worker executing: %s", cmd) + if cmd in _Executor.registry: + res = _Executor.registry[cmd](self, *args) + else: + log.warning("Bad command in worker") + res = "ERROR" + + log.info("Worker answering to: %s", cmd) + ans_queue.put(res) + + registry = {} + + @register(registry) + def malloc(self, size): + return self.allocator.malloc(size) + + @register(registry) + def free(self, ptr): + return self.allocator.free(ptr) + + def _run_op(self, op_name, args, kwargs): + op, _ = torch._C._jit_get_operation(op_name) + args, kwargs = receive_after_sending(self.allocator, args, kwargs) + return op(*args, **kwargs) + + @register(registry) + def run_op(self, op_name, args, kwargs): + self._run_op(op_name, args, kwargs) + + @register(registry) + def get_op_output_shape(self, op_name, args, kwargs): + return self._run_op(op_name, args, kwargs).size() + + @register(registry) + def send_data(self, *args): + assert len(args) == 1 + return OpenRegTensorData.from_meta(self.allocator, args[0]) + + @register(registry) + def recv_data(self, host_tensor, dev_mem): + dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem) + dev_tensor.copy_(host_tensor) + + +driver = Driver() diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py new file mode 100644 index 00000000000000..18b3c1842fc027 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -0,0 +1,104 @@ +import pprint + +import torch +from torch.utils._pytree import tree_map, tree_map_only + + +class OpenRegTensorMeta: + def __init__(self, tensor, checked=True): + if checked and not tensor.device.type == "openreg": + raise RuntimeError( + "Creating OpenRegTensorMeta is only for Tensors on openreg device" + ) + self.data_ptr = tensor.untyped_storage().data_ptr() + self.size = tensor.size() + self.stride = tensor.stride() + self.storage_offset = tensor.storage_offset() + self.dtype = tensor.dtype + self.nelem_in_bytes = tensor.nelement() * tensor.element_size() + + def __repr__(self): + return ( + f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, " + f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})" + ) + + +class OpenRegTensorData(torch.Tensor): + @staticmethod + def from_meta(allocator, tensor_meta): + return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta)) + + +VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float} + +VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str} + + +def safe_str(args): + def convert(obj): + if isinstance(obj, torch.Tensor): + return str(OpenRegTensorMeta(obj, checked=False)) + else: + return obj + + new_args = tree_map(convert, args) + return pprint.pformat(new_args) + + +def validate_send_queue_args(cmd, args): + def check(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + if ( + cmd == "recv_data" + and type(obj) is torch.Tensor + and obj.device.type == "cpu" + ): + # Only HtoD copy command can send cpu Tensors over + return + raise RuntimeError( + f"Trying to send invalid object through queue: {type(obj)}" + ) + + tree_map(check, args) + + +def prepare_for_sending(args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_IN: + raise RuntimeError( + f"Cannot send object of type {type(obj)} " "over openreg device pipe." + ) + + if isinstance(obj, torch.Tensor): + return OpenRegTensorMeta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def receive_after_sending(allocator, args, kwargs): + def convert(obj): + if type(obj) not in VALID_QUEUE_TYPES_OUT: + raise RuntimeError( + f"Received invalid object of type {type(obj)} " + "over openreg device pipe." + ) + + if isinstance(obj, OpenRegTensorMeta): + return allocator.tensor_from_meta(obj) + else: + return obj + + return tree_map(convert, (args, kwargs)) + + +def to_device_no_copy(device, args, kwargs): + def safe_to(t): + if device == "meta": + return t.to(device=device) + else: + return torch.empty_like(t, device=device) + + return tree_map_only(torch.Tensor, safe_to, (args, kwargs)) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h index 68bfeef1fdc889..6cf2b429bea076 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h @@ -6,5 +6,6 @@ namespace openreg { void set_impl_registry(PyObject* registry); + py::function get_method(const char* name); } \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp index 238ed8fddaf4ba..dae7b3b1db9608 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp @@ -11,10 +11,6 @@ namespace { // Python dictionary where real implementations can be found PyObject* py_registry; -py::function get_method(const char* name) { - return py::cast(py_registry)[name]; -} - // C++ hooks implementation struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {}; @@ -243,4 +239,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); void set_impl_registry(PyObject* registry) { py_registry = registry; } + +py::function get_method(const char* name) { + auto dict = py::cast(py_registry); + TORCH_CHECK(dict.contains(name), "OpenReg registry does not contain ", + "an implementation for '", name, "' make sure to add it in the __init__.py " + "file and register it.") + return dict[name]; +} } // openreg \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp new file mode 100644 index 00000000000000..91b2dfcec4b4cd --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp @@ -0,0 +1,125 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace openreg { + +namespace { + +using openreg_ptr_t = uint64_t; + +// A dummy allocator for our custom device, that secretly uses the CPU +struct OpenRegAllocator final : at::Allocator { + OpenRegAllocator() = default; + + at::DataPtr allocate(size_t nbytes) override { + py::gil_scoped_acquire acquire; + auto curr_device_idx = get_method("getDevice")().cast(); + auto curr_device = c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx); + void* data = nullptr; + if (nbytes > 0) { + data = reinterpret_cast(get_method("malloc")(nbytes).cast()); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + py::gil_scoped_acquire acquire; + TORCH_CHECK( + get_method("free")(reinterpret_cast(ptr)).cast(), + "Failed to free memory pointer at ", ptr + ); + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + py::gil_scoped_acquire acquire; + get_method("copy_data")(reinterpret_cast(dest), reinterpret_cast(src), count); + } +}; + +// Register our dummy allocator +static OpenRegAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + + +// Empty op needs C++ code and cannot be handled by python side fallback +at::Tensor empty_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported"); + TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_generic( + size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt); +} + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK(c10::layout_or_default(layout_opt) == c10::Layout::Strided, "Non strided layout not supported"); + TORCH_CHECK(!c10::pinned_memory_or_default(pin_memory_opt), "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + return at::detail::empty_strided_generic( + size, stride, &global_openreg_alloc, pu1_dks, dtype); +} + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional storage_offset_) { + // Metadata-only change so we re-use the cpu impl + return at::cpu::as_strided(self, size, stride, storage_offset_); +} + +at::Tensor& set_openreg( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", empty_openreg); + m.impl("empty_strided", empty_strided_openreg); + m.impl("as_strided", as_strided_openreg); + m.impl("set_.source_Storage_storage_offset", set_openreg); +} + +} // anonymous namspaces + +} // openreg diff --git a/test/cpp_extensions/open_registration_extension/test/test_openreg.py b/test/cpp_extensions/open_registration_extension/test/test_openreg.py index 87c35edf695ca3..27689c7559aca7 100644 --- a/test/cpp_extensions/open_registration_extension/test/test_openreg.py +++ b/test/cpp_extensions/open_registration_extension/test/test_openreg.py @@ -7,17 +7,17 @@ import pytorch_openreg import torch -from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase class TestOpenReg(TestCase): def test_initializes(self): self.assertEqual(torch._C._get_privateuse1_backend_name(), "openreg") - @unittest.skipIf(not IS_LINUX, "Only works on linux") + @unittest.SkipTest def test_autograd_init(self): # Make sure autograd is initialized - torch.rand(2, requires_grad=True, device="openreg").sum().backward() + torch.ones(2, requires_grad=True, device="openreg").sum().backward() pid = os.getpid() task_path = f"/proc/{pid}/task" @@ -30,9 +30,39 @@ def test_autograd_init(self): thread_name = file.read().strip() all_thread_names.add(thread_name) - for i in range(pytorch_openreg.NUM_DEVICES): + for i in range(pytorch_openreg._device_daemon.NUM_DEVICES): self.assertIn(f"pt_autograd_{i}", all_thread_names) + def test_factory(self): + a = torch.empty(50, device="openreg") + self.assertEqual(a.device.type, "openreg") + + a.fill_(3.5) + + self.assertTrue(a.eq(3.5).all()) + + def test_printing(self): + a = torch.ones(20, device="openreg") + # Does not crash! + str(a) + + def test_cross_device_copy(self): + a = torch.rand(10) + b = a.to(device="openreg").add(2).to(device="cpu") + self.assertEqual(b, a + 2) + + def test_copy_same_device(self): + a = torch.ones(10, device="openreg").clone() + self.assertEqual(a, torch.ones(10, device="openreg")) + + def test_data_dependent_output(self): + cpu_a = torch.randn(10) + a = cpu_a.to(device="openreg") + mask = a.gt(0) + out = torch.masked_select(a, mask) + + self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0))) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py index 521b20ca5b0bdf..636c98ec5ba590 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py @@ -8,8 +8,8 @@ import torch.nn as nn from torch.distributed._composable import replicate from torch.distributed._composable.fsdp import fully_shard -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLPStack from torch.testing._internal.common_utils import run_tests diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 6e92ce3b36e62e..5641082a8063cd 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -32,9 +32,9 @@ from torch.distributed._composable.fsdp._fsdp_param import ShardedState from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup from torch.distributed._tensor import DTensor -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed._tensor.experimental import implicit_replication from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index b187c867694171..ecd1e7a6d9aa27 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -4,7 +4,10 @@ import contextlib import copy import functools +import itertools +import logging import unittest +from collections import defaultdict from unittest import mock import torch @@ -30,14 +33,20 @@ from torch.utils._triton import has_triton -def _is_op_in_graph(graph, op): - return any(node.target is op for node in graph.nodes) +log = logging.getLogger(__name__) + + +def _count_op_in_graph(graph, op): + return sum(1 for node in graph.nodes if node.target is op) def _is_fallback_op_in_snodes(snodes, op): return any(is_fallback_op(snode.node, op) for snode in snodes) +orig_F_scaled_dot_product_attention = F.scaled_dot_product_attention + + class TestFullyShardCompileCompute(FSDPTest): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(2) @@ -130,7 +139,7 @@ def f(x): self.assertEqual(cnt.op_count, 1) self.assertEqual(len(cnt.graphs), 1) - def test_trace_fsdp_set_(self): + def test_trace_fsdp_copy_(self): @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"}) def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None: torch.add(x, 1, out=out) @@ -140,7 +149,7 @@ def f(x): buf_view = buf.view(-1) torch.ops.mylib.add_one_out(x, out=buf_view) buf_view2 = buf.view(-1) - torch.ops.fsdp.set_(x, buf_view2) + torch.ops.fsdp.copy_(x, buf_view2) ref_x = torch.zeros(2) x = copy.deepcopy(ref_x) @@ -148,26 +157,119 @@ def f(x): torch.compile(f, backend="aot_eager")(x) self.assertEqual(x, ref_x) + def _assert_no_aliased_unsharded_params_in_graph_inputs( + self, model, graph: torch.fx.Graph + ) -> None: + # FSDP2 unsharded params are mutated in the graph without going through functionalization. + # Therefore, we want to make sure they don't have aliases in the graph inputs, to make it easier + # for us to do the replacement of unsharded params with the all-gathered temporary buffer directly + # in downstream users in the graph. + storage_id_to_graph_inputs = defaultdict(list) + unsharded_param_graph_inputs = set() + for node in graph.nodes: + if ( + node.op == "call_function" + and node.target + in [ + torch.ops.inductor.resize_storage_bytes_.default, + torch.ops.fsdp.copy_.default, + ] + and node.args[0].op == "placeholder" + ): + unsharded_param_graph_inputs.add(node.args[0]) + assert len(unsharded_param_graph_inputs) > 0 + assert len(unsharded_param_graph_inputs) == len( + list(model.parameters()) + ), """\ +Expected all model parameters to be wrapped by FSDP2 and +have their unsharded version as graph input, but it's not true! +""" + no_aliased_unsharded_params_in_graph_inputs = True + err_msg = "" + for aliased_graph_inputs in storage_id_to_graph_inputs.values(): + if len(aliased_graph_inputs) > 1 and any( + x in unsharded_param_graph_inputs for x in aliased_graph_inputs + ): + no_aliased_unsharded_params_in_graph_inputs = False + err_msg += f"""\n +Found aliased unsharded param in graph inputs: {aliased_graph_inputs}, +val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, +""" + self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg) + + def _remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks( + self, model, fullgraph + ): + def _run_with_checks(graph, orig_fn): + self._assert_no_aliased_unsharded_params_in_graph_inputs(model, graph) + orig_fn(graph) + + if fullgraph: + return mock.patch.object( + comms, + "remove_fsdp2_unsharded_param_graph_input_usage", + functools.partial( + _run_with_checks, + orig_fn=comms.remove_fsdp2_unsharded_param_graph_input_usage, + ), + ) + else: + return contextlib.nullcontext() + + def _check_fsdp_copy_and_resize_ops_count_in_graph( + self, + graph, + *, + fwd_copy_count, + fwd_resize_count, + bwd_copy_count, + bwd_resize_count, + ): + def _check_count(copy_count, resize_count): + actual_copy_count = _count_op_in_graph(graph, torch.ops.fsdp.copy_.default) + self.assertEqual( + actual_copy_count, + copy_count, + f"Unexpected number of `fsdp.copy_` ops (expected {copy_count}, got {actual_copy_count}) in graph: {graph}", + ) + + actual_resize_count = _count_op_in_graph( + graph, torch.ops.inductor.resize_storage_bytes_.default + ) + self.assertEqual( + actual_resize_count, + resize_count, + f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950 + ) + + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: + _check_count(fwd_copy_count, fwd_resize_count) # fwd graph + else: + _check_count(bwd_copy_count, bwd_resize_count) # bwd graph + def _reinplace_all_gather_with_optional_checks(self, fullgraph): def _run_with_checks(graph, orig_fn): - self.assertTrue( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor.default, - ) + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, ) + orig_fn(graph) - self.assertFalse( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor.default, - ) + + self.assertEqual( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor.default + ), + 0, ) - self.assertTrue( - _is_op_in_graph( - graph, - torch.ops._c10d_functional.all_gather_into_tensor_out.default, - ) + + self.assertGreater( + _count_op_in_graph( + graph, torch.ops._c10d_functional.all_gather_into_tensor_out.default + ), + 0, ) if fullgraph: @@ -266,8 +368,6 @@ def inductor_code_check_fsdp_all_gather( self, file_check, overlapped_compute_op_str, - num_resize, - num_set, last_all_gather=False, ): file_check = file_check.check("torch.ops.fsdp.all_gather_copy_in.") @@ -278,16 +378,9 @@ def inductor_code_check_fsdp_all_gather( # Checks that AGWait is delayed, making the AG overlap with some compute op. if overlapped_compute_op_str is not None: file_check = file_check.check(f"{overlapped_compute_op_str}") - file_check = file_check.check_count( - "inductor_ops.resize_storage_bytes_(", num_resize, exactly=True - ) file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.") file_check = self.inductor_code_check_no_compute_op(file_check) file_check = file_check.check("torch.ops.fsdp.split_with_sizes_copy.") - file_check = self.inductor_code_check_no_compute_op(file_check) - file_check = file_check.check_count( - "torch.ops.aten.set_.", num_set, exactly=True - ) if not last_all_gather: # Checks that there is no compute op between this AGWait and next AG. file_check = self.inductor_code_check_no_compute_op(file_check) @@ -307,22 +400,12 @@ def inductor_code_check_fsdp_reduce_scatter( file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.") return file_check - @torch._dynamo.config.patch( - inline_inbuilt_nn_modules=True, - skip_fsdp_hooks=False, - ) - @torch._functorch.config.patch(recompute_views=True) - @torch._functorch.config.patch(cse=False) - @torch._inductor.config.patch( - reorder_for_compute_comm_overlap=True, - reorder_for_compute_comm_overlap_passes=[ - "sink_waits", - "raise_comms", - "reorder_compute_for_overlap", - ], - ) def _test_traceable_fsdp( - self, model_init_fn, input_creation_fn, backend, fullgraph + self, + model_init_fn, + input_creation_fn, + backend, + fullgraph, ): def compiler_fn(compiled_autograd_backend): def _fn(gm): @@ -334,7 +417,12 @@ def _fn(gm): return _fn - def run_iters(model, optim, n_iter=10, compiled_autograd_backend=None): + def run_iters( + model, + optim, + n_iter=10, + compiled_autograd_backend=None, + ): torch.manual_seed(42) losses = [] for i in range(n_iter): @@ -359,9 +447,18 @@ def test_compiled(): # FSDP2 does lazy init using 1st run, so run it once to init using eager mode run_iters(model, optim, n_iter=1) - model_compiled = torch.compile(model, backend=backend, fullgraph=fullgraph) - res = run_iters(model_compiled, optim, compiled_autograd_backend=backend) - return res + with self._remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks( + model, fullgraph + ): + model_compiled = torch.compile( + model, backend=backend, fullgraph=fullgraph + ) + res = run_iters( + model_compiled, + optim, + compiled_autograd_backend=backend, + ) + return res def test_eager(): model, optim = model_init_fn() @@ -371,7 +468,29 @@ def test_eager(): res = run_iters(model, optim) return res - losses_compiled = test_compiled() + torch._dynamo.reset() + torch._dynamo.compiled_autograd.reset() + with torch._dynamo.config.patch( + # NOTE: Setting fullgraph=False for forward (to allow graph-breaks) is a common scenario + # and in that case we need a standalone Compiled Autograd ctx that has fullgraph=True for backward. + # Hence here we explicitly set compiled_autograd=False and use the standalone Compiled Autograd ctx + # `maybe_compiled_autograd_ctx` created in `run_iters()`. + compiled_autograd=False, + inline_inbuilt_nn_modules=True, + skip_fsdp_hooks=False, + ), torch._functorch.config.patch( + enable_autograd_cache=False, + recompute_views=True, + ), torch._inductor.config.patch( + force_disable_caches=True, + reorder_for_compute_comm_overlap=True, + reorder_for_compute_comm_overlap_passes=[ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ], + ): + losses_compiled = test_compiled() losses_eager = test_eager() if not self.fake_pg: for loss_compiled, loss_eager in zip(losses_compiled, losses_eager): @@ -448,9 +567,9 @@ def __init__(self, hidden_dim): ) def forward(self, x): + ret = torch.matmul(x, self.param1) if not fullgraph: torch._dynamo.graph_break() - ret = torch.matmul(x, self.param1) ret = ret * self.param2 ret = torch.relu(ret) return ret @@ -515,11 +634,23 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): @skipIfRocm @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - def test_nested_fully_shard_backend_inductor(self): - for fullgraph in [True, False]: + def test_nested_fully_shard_backend_inductor_fullgraph_True(self): + for fullgraph in [True]: with self._reinplace_all_gather_with_optional_checks( fullgraph - ), self._maybe_run_decide_global_ordering_of_comms_with_checks(fullgraph): + ), self._maybe_run_decide_global_ordering_of_comms_with_checks( + fullgraph + ), torch._inductor.config.patch( + post_grad_custom_post_pass=functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + fwd_copy_count=0, + fwd_resize_count=0, + bwd_copy_count=0, + bwd_resize_count=0, + ) + if fullgraph + else None + ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( *self._create_nested_fully_shard_factory_fns( @@ -530,53 +661,38 @@ def test_nested_fully_shard_backend_inductor(self): ) ) if fullgraph: - self.assertTrue( - len(triton_codes) == 2, + self.assertEqual( + len(triton_codes), + 2, "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph", ) fwd_code = triton_codes[0] file_check = FileCheck().check("def call(args):") for fwd_ag_block_info in [ - dict(overlapped_compute_op_str=None, num_resize=0, num_set=2), + dict(overlapped_compute_op_str=None), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=2, - num_set=2, last_all_gather=True, ), ]: @@ -588,16 +704,12 @@ def test_nested_fully_shard_backend_inductor(self): bwd_code = triton_codes[1] file_check = FileCheck().check("def call(args):") for bwd_ag_block_info in [ - dict(overlapped_compute_op_str=None, num_resize=0, num_set=2), + dict(overlapped_compute_op_str=None), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=2, ), dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=2, last_all_gather=True, ), ]: @@ -605,7 +717,7 @@ def test_nested_fully_shard_backend_inductor(self): file_check, **bwd_ag_block_info ) for bwd_rs_block_info in [ - dict(overlapped_compute_op_str="extern_kernels.mm("), + dict(overlapped_compute_op_str="extern_kernels.addmm("), dict( overlapped_compute_op_str=None ), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None @@ -615,17 +727,31 @@ def test_nested_fully_shard_backend_inductor(self): file_check, **bwd_rs_block_info ) file_check.run(bwd_code) - else: - # TODO: when fullgraph=False and there is graph break in FWD graph, - # there are several recompiles, need to figure out why. - self.assertTrue( - len(triton_codes) > 2, - "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", - ) - def _create_transformer_factory_fns(self): + @skipIfRocm + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_nested_fully_shard_backend_inductor_fullgraph_False(self): + _, triton_codes = run_and_get_code( + lambda: self._test_traceable_fsdp( + *self._create_nested_fully_shard_factory_fns(fullgraph=False), + "inductor", + fullgraph=False, + ) + ) + # TODO: when fullgraph=False and there is graph break in FWD graph, + # there are several recompiles, need to figure out why. + self.assertGreater( + len(triton_codes), + 2, + "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + ) + + def _create_transformer_factory_fns( + self, all_requires_grad, *, activation_checkpoint=False + ): seq_len = 16 vocab_size = 8 + n_layers = 3 def model_init_fn(): torch.manual_seed(self.rank) @@ -633,9 +759,21 @@ def model_init_fn(): mesh = init_device_mesh("cuda", (self.world_size,)) model_args = ModelArgs( vocab_size=vocab_size, - n_layers=3, + n_layers=n_layers, + checkpoint_activations=activation_checkpoint, ) model = Transformer(model_args) + if not all_requires_grad: + requires_grad_params = ["attention.wq", "attention.wv"] + requires_grad_param_count = 0 + for k, v in model.named_parameters(): + for substring in requires_grad_params: + if substring in k: + v.requires_grad_(True) + requires_grad_param_count += 1 + else: + v.requires_grad_(False) + assert requires_grad_param_count == n_layers * len(requires_grad_params) for layer_id, mod in enumerate(model.layers): fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config) model = fully_shard( @@ -654,30 +792,32 @@ def input_creation_fn(): return model_init_fn, input_creation_fn def _maybe_add_graph_break_to_sdpa(self, fullgraph): - def _sdpa_with_graph_break(orig_fn, fullgraph, *args, **kwargs): - if not fullgraph: - torch._dynamo.graph_break() - return orig_fn(*args, **kwargs) - - return mock.patch.object( - F, - "scaled_dot_product_attention", - functools.partial( + def _sdpa_with_graph_break(*args, **kwargs): + torch._dynamo.graph_break() + return orig_F_scaled_dot_product_attention(*args, **kwargs) + + if not fullgraph: + return mock.patch.object( + F, + "scaled_dot_product_attention", _sdpa_with_graph_break, - F.scaled_dot_product_attention, - fullgraph, - ), - ) + ) + else: + return contextlib.nullcontext() @skipIfRocm @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") def test_transformer_backend_aot_eager(self): - for fullgraph in [True, False]: + for fullgraph, all_requires_grad in itertools.product( + [True, False], [True, False] + ): with self._maybe_add_graph_break_to_sdpa( fullgraph ), self._reinplace_all_gather_with_optional_checks(fullgraph): self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad + ), "aot_eager", fullgraph=fullgraph, ) @@ -687,10 +827,14 @@ def test_transformer_backend_aot_eager(self): # TODO: native_dropout has worse accuracy after decomp, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_aot_eager_decomp_partition(self): - for fullgraph in [True, False]: + for fullgraph, all_requires_grad in itertools.product( + [True, False], [True, False] + ): with self._maybe_add_graph_break_to_sdpa(fullgraph): self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad + ), "aot_eager_decomp_partition", fullgraph=fullgraph, ) @@ -699,45 +843,63 @@ def test_transformer_backend_aot_eager_decomp_partition(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) - def test_transformer_backend_inductor(self): - for fullgraph in [True, False]: - with self._maybe_add_graph_break_to_sdpa( - fullgraph - ), self._reinplace_all_gather_with_optional_checks( + def test_transformer_backend_inductor_fullgraph_True(self): + for fullgraph, all_requires_grad, activation_checkpoint in itertools.product( + [True], [True, False], [True, False] + ): + log.warning( + f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001 + ) + with self._reinplace_all_gather_with_optional_checks( fullgraph ), self._maybe_run_decide_global_ordering_of_comms_with_checks( fullgraph + ), torch._inductor.config.patch( + post_grad_custom_post_pass=functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + # NOTE: For the root unsharded params, we don't reshard after forward since for training, + # the parameters would be freed and all-gathered immediately. Hence we still have + # their resize and copy ops in the graph. + fwd_copy_count=4, + fwd_resize_count=4, + bwd_copy_count=0, + bwd_resize_count=4, + ) + if fullgraph + else None ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( - *self._create_transformer_factory_fns(), + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad, + activation_checkpoint=activation_checkpoint, + ), "inductor", fullgraph=fullgraph, ) ) if fullgraph: - self.assertTrue( - len(triton_codes) == 2, + self.assertEqual( + len(triton_codes), + 2, "Expected two separate lowerings to Triton code, one from FWD graph and one from Compiled Autograd BWD graph", ) fwd_code = triton_codes[0] file_check = FileCheck().check("def call(args):") for fwd_ag_block_info in [ - dict(overlapped_compute_op_str="triton_", num_resize=0, num_set=4), + dict( + overlapped_compute_op_str="triton_" + if all_requires_grad + else None, + ), dict( overlapped_compute_op_str="aten.native_dropout.", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.", - num_resize=12, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention.", - num_resize=12, - num_set=12, last_all_gather=True, ), ]: @@ -751,43 +913,66 @@ def test_transformer_backend_inductor(self): for bwd_ag_block_info in [ dict( overlapped_compute_op_str="extern_kernels.mm(", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.", - num_resize=0, - num_set=12, ), dict( overlapped_compute_op_str="aten._scaled_dot_product_efficient_attention_backward.", - num_resize=0, - num_set=12, last_all_gather=True, ), ]: - file_check = self.inductor_code_check_fsdp_all_gather( - file_check, **bwd_ag_block_info - ) + if bwd_ag_block_info is not None: + file_check = self.inductor_code_check_fsdp_all_gather( + file_check, **bwd_ag_block_info + ) for bwd_rs_block_info in [ - dict(overlapped_compute_op_str="extern_kernels.mm("), + dict(overlapped_compute_op_str="extern_kernels.mm(") + if all_requires_grad + else None, dict( overlapped_compute_op_str=None ), # TODO: improve compute/comm overlap, so that `overlapped_compute_op_str` is not None dict(overlapped_compute_op_str=None), - dict(overlapped_compute_op_str=None), + dict(overlapped_compute_op_str=None) if all_requires_grad else None, ]: - file_check = self.inductor_code_check_fsdp_reduce_scatter( - file_check, **bwd_rs_block_info - ) + if bwd_rs_block_info is not None: + file_check = self.inductor_code_check_fsdp_reduce_scatter( + file_check, **bwd_rs_block_info + ) file_check.run(bwd_code) - else: - # TODO: when fullgraph=False and there is graph break in FWD graph, - # there are several recompiles, need to figure out why. - self.assertTrue( - len(triton_codes) > 2, - "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + + @skipIfRocm + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO: native_dropout causes CUDA IMA error, need to figure out why + @torch._inductor.config.patch(fallback_random=True) + def test_transformer_backend_inductor_fullgraph_False(self): + fullgraph = False + # TODO: fix numerical issue in activation_checkpoint=True case + for all_requires_grad, activation_checkpoint in itertools.product( + [True, False], [False] + ): + log.warning( + f"fullgraph={fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001 + ) + with self._maybe_add_graph_break_to_sdpa(fullgraph): + _, triton_codes = run_and_get_code( + lambda: self._test_traceable_fsdp( + *self._create_transformer_factory_fns( + all_requires_grad=all_requires_grad, + activation_checkpoint=activation_checkpoint, + ), + "inductor", + fullgraph=fullgraph, + ) ) + # TODO: when fullgraph=False and there is graph break in FWD graph, + # there are several recompiles, need to figure out why. + self.assertGreater( + len(triton_codes), + 2, + "Expected at least 3 separate lowerings to Triton code, which means at least 1 graph break in FWD graph", + ) if __name__ == "__main__": diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py new file mode 100644 index 00000000000000..1d87e1b267d0d5 --- /dev/null +++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py @@ -0,0 +1,111 @@ +# Owner(s): ["oncall: distributed"] +import copy + +import torch +import torch.nn as nn +from torch.amp.grad_scaler import GradScaler, OptState +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed._tensor import init_device_mesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest, MLP +from torch.testing._internal.common_utils import run_tests, skipIfRocm + + +class TestFullyShardGradientScaler(FSDPTest): + @skip_if_lt_x_gpu(4) + @skipIfRocm + def test_gradient_scaler(self): + self.run_subtests( + {"has_inf": [True, False], "test_2d": [True, False]}, + self._test_gradient_scaler, + ) + + def _test_gradient_scaler(self, has_inf: bool, test_2d: bool): + torch.manual_seed(0) + model = nn.Sequential( + *[nn.Linear(4, 4, device="cuda", bias=False) for _ in range(2)] + ) + for layer in model: + fully_shard(layer) + fully_shard(model) + input = torch.randn([4, 4], device="cuda") + + if test_2d: + mesh_2d = init_device_mesh( + "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp") + ) + dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"] + model = nn.Sequential(MLP(2), MLP(2), MLP(2)) + tp_parallelize_plan = { + "0.in_proj": ColwiseParallel(), + "0.out_proj": RowwiseParallel(), + "1.in_proj": ColwiseParallel(), + "1.out_proj": RowwiseParallel(), + "2.in_proj": ColwiseParallel(), + "2.out_proj": RowwiseParallel(), + } + model = parallelize_module( + model, + device_mesh=tp_mesh, + parallelize_plan=tp_parallelize_plan, + ) + for module in model: + fully_shard(module, mesh=dp_mesh) + fully_shard(model, mesh=dp_mesh) + input = torch.randn((2,), device="cuda") + + loss = model(input).sum() + scaler = GradScaler(init_scale=2.0, enabled=True) + opt = torch.optim.Adam(model.parameters(), lr=1e-2) + scaler.scale(loss).backward() + inv_scale = scaler._scale.double().reciprocal().float() + if ( + has_inf is True + and opt.param_groups[0]["params"][0].grad._local_tensor.device.index == 1 + ): + opt.param_groups[0]["params"][0].grad._local_tensor[0, 0].fill_( + float("inf") + ) + inital_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() + + scaler.unscale_(opt) + for found_inf in scaler._per_optimizer_states[id(opt)][ + "found_inf_per_device" + ].values(): + self.assertEqual(found_inf, has_inf) + self.assertEqual( + scaler._per_optimizer_states[id(opt)]["stage"].value, + OptState.UNSCALED.value, + ) + unscaled_grad = opt.param_groups[0]["params"][0].grad.to_local().clone() + self.assertEqual(unscaled_grad, inital_grad * inv_scale) + initial_scale = scaler.get_scale() + initial_state = copy.copy(opt.state) + + scaler.step(opt) + steped_state = copy.copy(opt.state) + if has_inf: + # assert parameters are the same before/after + self.assertEqual(steped_state, initial_state) + else: + # new parameters here if no inf found during .unscale_() + self.assertNotEqual(steped_state.items(), initial_state.items()) + + scaler.update() + updated_scale = scaler.get_scale() + if has_inf: + # assert scale is updated + backoff_factor = scaler.get_backoff_factor() + self.assertEqual(updated_scale, initial_scale * backoff_factor) + else: + # scale is not updated + self.assertEqual(updated_scale, initial_scale) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 0ce3a029ca4318..c0c585f9d767c5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -23,7 +23,6 @@ Replicate, Shard, ) -from torch.distributed._tensor.placement_types import _StridedShard from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp._init_utils import ( _init_inter_node_process_group, @@ -34,6 +33,7 @@ parallelize_module, RowwiseParallel, ) +from torch.distributed.tensor.placement_types import _StridedShard from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP from torch.testing._internal.common_utils import run_tests diff --git a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py index 46f22a12d9d703..3ed5853908298d 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py @@ -3,6 +3,7 @@ import copy import functools import unittest +from contextlib import nullcontext from typing import Dict import torch @@ -77,8 +78,21 @@ def _test_dp_state_dict_save_load(self, mlp_dim: int, mesh: DeviceMesh): @skip_if_lt_x_gpu(2) def test_dp_state_dict_cpu_offload(self): + self.run_subtests( + { + "offload_policy": [ + CPUOffloadPolicy(pin_memory=True), + CPUOffloadPolicy(pin_memory=False), + ], + "cpu_state_dict": [True, False], + }, + self._test_dp_state_dict_cpu_offload, + ) + + def _test_dp_state_dict_cpu_offload( + self, offload_policy: CPUOffloadPolicy, cpu_state_dict: bool + ): mlp_dim = 4 - offload_policy = CPUOffloadPolicy(pin_memory=True) torch.manual_seed(42) with torch.device("meta"): model = nn.Sequential( @@ -97,6 +111,8 @@ def test_dp_state_dict_cpu_offload(self): sharded_tensor = distribute_tensor( full_tensor, dtensor.device_mesh, dtensor.placements ) + if cpu_state_dict: + sharded_tensor = sharded_tensor.cpu() state_dicts.append({name: sharded_tensor}) # check that we can load with some parameters still on meta device @@ -105,11 +121,20 @@ def test_dp_state_dict_cpu_offload(self): # lazy init without error inp = torch.rand((mlp_dim, mlp_dim), device="cuda") - model(inp) - state_dict = model.state_dict() - for name, dtensor in state_dict.items(): - self.assertEqual(dtensor.device.type, "cpu") + context = ( + self.assertRaisesRegex( + RuntimeError, + r"Found following parameters on non-CPU device: \[\('0.weight', device\(type='cuda'", + ) + if not cpu_state_dict + else nullcontext() + ) + with context: + model(inp).sum() + state_dict = model.state_dict() + for name, dtensor in state_dict.items(): + self.assertEqual(dtensor.device.type, "cpu") def test_2d_state_dict_correctness(self): dp_size = 2 diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 66417c5c1edda4..ab52cb925709af 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -20,12 +20,12 @@ register_fsdp_forward_method, ) from torch.distributed._tensor import DTensor, init_device_mesh -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_PREFIX, apply_activation_checkpointing, ) from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -285,6 +285,7 @@ def test_train_parity_multi_group(self): "delay_before_all_gather": [False, True], "delay_before_reduce_scatter": [False, True], "delay_before_optim": [False, True], + "unshard_async_op": [False], }, self._test_train_parity_multi_group, ) @@ -307,6 +308,28 @@ def test_train_parity_multi_group_cpu_offload_eager(self): "delay_before_all_gather": [False, True], "delay_before_reduce_scatter": [False, True], "delay_before_optim": [False, True], + "unshard_async_op": [False], + }, + self._test_train_parity_multi_group, + ) + + @skip_if_lt_x_gpu(2) + @test_compiled_fsdp(compile_compute_on_module=Transformer) + def test_train_parity_multi_group_unshard_async_op(self): + """ + Tests train parity against DDP when using multiple parameter groups for + communication and setting ``unshard_async_op=True``. + """ + self.run_subtests( + { + "reshard_after_forward": [True], + "device_type": ["cuda"], + "offload_policy": [OffloadPolicy()], + "delay_after_forward": [False, True], + "delay_before_all_gather": [False, True], + "delay_before_reduce_scatter": [False, True], + "delay_before_optim": [False, True], + "unshard_async_op": [True], }, self._test_train_parity_multi_group, ) @@ -320,6 +343,7 @@ def _test_train_parity_multi_group( delay_before_all_gather: bool, delay_before_reduce_scatter: bool, delay_before_optim: bool, + unshard_async_op: bool, ): # Only test individual delays or all four delays to save test time if ( @@ -359,6 +383,8 @@ def _test_train_parity_multi_group( if isinstance(module, TransformerBlock): fully_shard_fn(module) fully_shard_fn(model) + if unshard_async_op: + model._set_unshard_async_op(unshard_async_op) optim = torch.optim.Adam(model.parameters(), lr=1e-2) delay_in_ms = 100 diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 5f2c9dbe1240b7..83b0f8f2b5ac62 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -14,11 +14,12 @@ from torch.distributed._composable import replicate from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, + set_model_state_dict, set_optimizer_state_dict, + StateDictOptions, ) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -27,6 +28,7 @@ clean_tensor_name, ) from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -169,6 +171,68 @@ def _test_train_parity_2d_mlp( _optim.step() self.assertEqual(losses[0], losses[1]) + @skip_if_lt_x_gpu(2) + @skipIfRocm + def test_train_parity_2d_transformer(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=3, dropout_p=0.0) + model = Transformer(model_args) + ref_model = copy.deepcopy(model).cuda() + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + + dp_size, tp_size = self.world_size // 2, 2 + global_mesh = init_device_mesh( + "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True) + + for layer in model.layers: + fully_shard(layer, mesh=global_mesh["dp"]) + fully_shard(model, mesh=global_mesh["dp"]) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + torch.manual_seed(42 + global_mesh.get_local_rank("dp")) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + for iter_idx in range(5): + ref_loss = ref_model(inp).sum() + loss = model(inp).sum() + self.assertEqual(ref_loss, loss) + ref_loss.backward() + loss.backward() + for param in ref_model.parameters(): + if param.grad is not None: + dist.all_reduce( + param.grad, + group=global_mesh.get_group("dp"), + op=dist.ReduceOp.AVG, + ) + + # Specially check the TP placement for `pos_embeddings.weight` and + # its which since the grad naturally has replicate placement, + # requiring FSDP to redistribute it to shard placement before FSDP + # runs its reduce-scatter + self.assertIsInstance(model.pos_embeddings.weight.placements[1], Shard) + self.assertIsInstance(model.pos_embeddings.weight.grad.placements[1], Shard) + for ref_param, (param_name, param) in zip( + ref_model.parameters(), model.named_parameters() + ): + full_grad = param.grad.full_tensor() + ref_grad = ref_param.grad + self.assertEqual(ref_param.grad, full_grad) + + ref_optim.step() + optim.step() + ref_optim.zero_grad() + optim.zero_grad() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + @skip_if_lt_x_gpu(2) @skipIfRocm def test_tp_with_fsdp_offloading(self): @@ -335,6 +399,57 @@ def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): self.assertEqual(loss_no_cp2, loss_cp2) +class TestFullyShard2DStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + + @with_comms + @skip_if_lt_x_gpu(4) + def test_fully_shard_tp_2d_set_full_state_dict(self): + dummy_model = SimpleModel().cuda() + mesh_2d = init_device_mesh( + "cuda", + (2, self.world_size // 2), + mesh_dim_names=("dp", "tp"), + ) + tp_mesh = mesh_2d["tp"] + dp_mesh = mesh_2d["dp"] + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + } + model = parallelize_module(dummy_model, tp_mesh, parallelize_plan) + fully_shard(model, mesh=dp_mesh) + optim = torch.optim.Adam(model.parameters(), lr=0.01) + model(model.get_input()).sum().backward() + optim.step() + # ref_msd, ref_osd are both the default sharded state dict + ref_msd = copy.deepcopy(get_model_state_dict(model)) + ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim)) + + options = StateDictOptions( + full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True + ) + full_msd = get_model_state_dict(model, options=options) + full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options) + # load full_msd and full_osd into model and optim. + # this loads the slice of full tensor into each rank's local DTensor. + set_model_state_dict(model, full_msd, options=options) + set_optimizer_state_dict( + model, optimizers=optim, optim_state_dict=full_osd, options=options + ) + + # check after setting full state dict, the model and optim default sharded state dict + # are the same as the initial default sharded state dict. + new_msd = get_model_state_dict(model) + new_osd = get_optimizer_state_dict(model, optimizers=optim) + self.assertEqual(ref_msd, new_msd) + self.assertEqual(ref_osd, new_osd) + + class Test2dFSDP1ParallelIntegration(DTensorTestBase): def init_model(self, device_type, model_parallel_size=2): torch.manual_seed(0) @@ -544,6 +659,11 @@ def test_2d_e2e_training_not_use_orig_params(self): # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict, # and consolidate all the state_dict test in test.distributed.checkpoint. class TestNew2dParallelStateDict(DTensorTestBase): + @property + def backend(self): + # need to specify gloo backend for testing cpu offload + return "cpu:gloo,cuda:nccl" + @with_comms @skip_if_lt_x_gpu(4) def test_fsdp_2d_extension(self): @@ -770,6 +890,81 @@ def test_2d_optim_state_dict(self, is_even_sharded_model): else: self.assertEqual(new_state, state) + @with_comms + @with_temp_dir + @skip_if_lt_x_gpu(4) + def test_fsdp1_tp_2d_set_full_state_dict(self): + """ + This is a workaround for loading full state dict into a FSDP1+TP 2D model. + Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict + and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do: + 1) load the full state dict into a 1D FSDP model + 2) dcp.save the full/shard state dict into storage + 3) initialize a 2D FSDP1+TP model + 4) get the default sharded state dict for the 2D model (full_state_dict=False) + 5) dcp.load the state dict from storage + 6) load the state dict into the 2D model + """ + dummy_model = SimpleModel().cuda() + mesh_1d = init_device_mesh("cuda", (self.world_size,)) + model = FSDP(dummy_model, device_mesh=mesh_1d) + optim = torch.optim.Adam(model.parameters(), lr=0.01) + model(model.get_input()).sum().backward() + optim.step() + ref_full_msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + ref_full_osd = get_optimizer_state_dict( + model, + optimizers=optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + state_dict = {"model": ref_full_msd, "optim": ref_full_osd} + # save the full state dict into storage first + dcp.save(state_dict, checkpoint_id=self.temp_dir) + + # initialize 2d model + dummy_model = SimpleModel().cuda() + mesh_2d = init_device_mesh( + "cuda", + (2, self.world_size // 2), + mesh_dim_names=("dp", "tp"), + ) + tp_mesh = mesh_2d["tp"] + dp_mesh = mesh_2d["dp"] + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + "net3": ColwiseParallel(), + } + model_2d = parallelize_module(dummy_model, tp_mesh, parallelize_plan) + model_2d = FSDP(model_2d, device_mesh=dp_mesh, use_orig_params=True) + optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.01) + # get the default sharded state dict for model_2d + # note this is because we can not set full_state_dict back to 2D directly + msd = get_model_state_dict(model_2d) + osd = get_optimizer_state_dict(model_2d, optimizers=optim_2d) + state_dict = {"model": msd, "optim": osd} + dcp.load(state_dict=state_dict, checkpoint_id=self.temp_dir) + + set_model_state_dict(model_2d, state_dict["model"]) + set_optimizer_state_dict( + model_2d, optimizers=optim_2d, optim_state_dict=state_dict["optim"] + ) + + # check after setting sharded state dict, the model and optim full state dict + # are the same as the initial full state dict. + new_full_msd = get_model_state_dict( + model, options=StateDictOptions(full_state_dict=True, cpu_offload=True) + ) + new_full_osd = get_optimizer_state_dict( + model, + optimizers=optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + self.assertEqual(ref_full_msd, new_full_msd) + self.assertEqual(ref_full_osd, new_full_osd) + instantiate_parametrized_tests(TestNew2dParallelStateDict) diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index 328f677be7fcfe..e173bb34e3eaa2 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -17,6 +17,7 @@ ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, ScheduleLoopedBFS, ) from torch.nn.parallel import DistributedDataParallel as DDP @@ -86,6 +87,7 @@ def device(self): ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleFlexibleInterleaved1F1B, + ScheduleInterleavedZeroBubble, ], ) @parametrize("use_new_runtime", [False, True]) @@ -233,7 +235,9 @@ def build_stage(stage_idx, num_stages): name = ".".join(parts) ref_p = ref_parameters[name] self.assertTrue(isinstance(p.grad, DTensor)) - self.assertEqual(ref_p.grad, p.grad.full_tensor()) + torch.testing.assert_close( + ref_p.grad, p.grad.full_tensor(), rtol=1e-5, atol=5e-5 + ) elif dp_type == "DDP": for partial_model, offset in zip(partial_models, offsets): for name, p in partial_model.named_parameters(): @@ -241,7 +245,7 @@ def build_stage(stage_idx, num_stages): parts[0] = str(int(parts[0]) + offset) name = ".".join(parts) ref_p = ref_parameters[name] - self.assertEqual(ref_p.grad, p.grad) + torch.testing.assert_close(ref_p.grad, p.grad, rtol=1e-5, atol=5e-5) torch.distributed.destroy_process_group() diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 41ce99090018b0..e0df5aaff27f84 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -29,9 +29,10 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.distributed.fake_pg import FakeStore from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint @@ -216,21 +217,21 @@ def test_compile_cpu_no_sync(self): self._test_compile(use_gpu=False, no_sync=True) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=False) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch(reorder_for_locality=False) def test_compile_gpu_ac(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=True) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_bf16(self): def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -244,7 +245,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(use_gpu=True, no_sync=False, setup_func=setup) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_fp16(self): def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -261,7 +262,7 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: ) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True) @@ -367,35 +368,28 @@ def test_bucketing_concat_op(self): fc.run(code) -class DDP_TP_Test(MultiProcessInductorTestCase): - @property - def world_size(self) -> int: - return min(4, torch.cuda.device_count()) +class DDP_TP_Test(InductorTestCase): + def setUp(self): + self.rank = 0 + self.world_size = 4 + torch.cuda.set_device("cuda:0") - def setUp(self) -> None: - super().setUp() - self._spawn_processes() + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) def tearDown(self): - super().tearDown() - try: - os.remove(self.file_name) - except OSError: - pass + dist.destroy_process_group() @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm - @skip_if_lt_x_gpu(4) + @skipIfRocm def test_ddp_tp(self): - torch.cuda.set_device(f"cuda:{self.rank}") - dist.init_process_group( - backend="nccl", - rank=self.rank, - world_size=self.world_size, - store=dist.FileStore(self.file_name, self.world_size), - ) - model = Net().cuda() - compiled_replicate_model = deepcopy(model) + ref_model = Net() + compiled_replicate_model = deepcopy(ref_model) mesh_2d = init_device_mesh( "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp") ) @@ -407,8 +401,8 @@ def test_ddp_tp(self): "fc3": ColwiseParallel(), "fc4": RowwiseParallel(), } - model = parallelize_module(model, tp_mesh, parallelize_plan) - model = replicate(model, device_mesh=dp_mesh) + ref_model = parallelize_module(ref_model, tp_mesh, parallelize_plan) + ref_model = replicate(ref_model, device_mesh=dp_mesh) compiled_replicate_model = parallelize_module( compiled_replicate_model, tp_mesh, parallelize_plan ) @@ -416,15 +410,23 @@ def test_ddp_tp(self): compiled_replicate_model, device_mesh=dp_mesh ) compiled_replicate_model = torch.compile(compiled_replicate_model) - data = torch.randn([1, DIM]).cuda() + data = torch.randn([1, DIM]) with compiled_autograd.enable(compiler_fn()): loss = compiled_replicate_model(data).sum() - loss.backward() + # TODO: We need "pre-dispatch tracing of backward graph" to make this work: + # https://github.com/pytorch/pytorch/issues/127797#issuecomment-2291695474 + with self.assertRaisesRegex( + AssertionError, + "Expected ProxyTensor, got ", + ): + loss.backward() - loss = model(data).sum() - loss.backward() - for p1, p2 in zip(model.parameters(), compiled_replicate_model.parameters()): - self.assertEqual(p1.grad, p2.grad) + # ref_loss = ref_model(data).sum() + # ref_loss.backward() + # for p1, p2 in zip( + # ref_model.parameters(), compiled_replicate_model.parameters() + # ): + # self.assertEqual(p1.grad, p2.grad) if __name__ == "__main__": diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py index 7d5f56c30118ce..3428bca2c83ba3 100644 --- a/test/distributed/_tensor/debug/test_comm_mode.py +++ b/test/distributed/_tensor/debug/test_comm_mode.py @@ -5,8 +5,8 @@ import torch.distributed._functional_collectives as funcol import torch.nn as nn from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.debug.comm_mode import CommDebugMode from torch.distributed._tensor.placement_types import Shard +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import requires_nccl from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule diff --git a/test/distributed/_tensor/debug/test_comm_mode_features.py b/test/distributed/_tensor/debug/test_comm_mode_features.py index 05469b714cf57b..fc19cddb58f4a3 100644 --- a/test/distributed/_tensor/debug/test_comm_mode_features.py +++ b/test/distributed/_tensor/debug/test_comm_mode_features.py @@ -6,7 +6,7 @@ import torch from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor.api import distribute_tensor, DTensor -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/_tensor/debug/test_op_coverage.py b/test/distributed/_tensor/debug/test_op_coverage.py index 775a9f0d525d98..0392bae03c504a 100644 --- a/test/distributed/_tensor/debug/test_op_coverage.py +++ b/test/distributed/_tensor/debug/test_op_coverage.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from torch.distributed._tensor.debug._op_coverage import get_inductor_decomp_graphs +from torch.distributed.tensor.debug._op_coverage import get_inductor_decomp_graphs from torch.testing._internal.common_utils import run_tests, TestCase diff --git a/test/distributed/_tensor/experimental/test_local_map.py b/test/distributed/_tensor/experimental/test_local_map.py index b483194d6c3a82..85d9977310c1d4 100644 --- a/test/distributed/_tensor/experimental/test_local_map.py +++ b/test/distributed/_tensor/experimental/test_local_map.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] +from functools import partial import torch import torch.distributed._functional_collectives as funcol @@ -10,8 +11,8 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.experimental import local_map +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -22,6 +23,11 @@ funcol_py = torch.ops.c10d_functional +row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh +col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh +replicate = [Replicate()] # replicate placements on 1-d mesh + + def equal_allgather_forward(device_mesh, X, Y): eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) @@ -42,6 +48,16 @@ def mm_allreduce_forward(device_mesh, A, B): return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() +@partial( + local_map, + out_placements=replicate, + in_placements=(None, col_wise, row_wise), +) +def mm_allreduce_forward_decorated(device_mesh, A, B): + partial_sum_tensor = torch.mm(A, B) + return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() + + def mul_forward(X, scalar): # no device mesh needed since we don't do collective return torch.mul(X, scalar) @@ -59,20 +75,19 @@ def test_local_map_correctness(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh - replicate = [Replicate()] - W_dt = distribute_tensor( - W, device_mesh, col_wise - ) # col-wisely sharded W tensor X_dt = distribute_tensor( - X, device_mesh, row_wise - ) # row-wisely sharded X tensor + X, device_mesh, col_wise + ) # col-wisely sharded X tensor + W_dt = distribute_tensor( + W, device_mesh, row_wise + ) # row-wisely sharded W tensor + + # Test 1: use the function returned from calling local_map # get the function wrapped with DTensor/Tensor convertion # mm_allreduce_forward is a function that applies to Tensors with manual collective # local_mm_allreduce_forward is the function that does the same but applies to @@ -84,7 +99,19 @@ def test_local_map_correctness(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) + + # output redistribution to Replicate + self.assertEqual(comm_mode.get_total_counts(), 1) + # check output placements + for placement in Y_dt.placements: + self.assertTrue(placement.is_replicate()) + # check output value + self.assertEqual(Y_dt.to_local(), Y) + + # Test 2: use the local_map decorator + with comm_mode: + Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt) # output redistribution to Replicate self.assertEqual(comm_mode.get_total_counts(), 1) @@ -106,7 +133,6 @@ def test_local_map_out_placements(self): # X.equal(Y) X = torch.randn(8, 8, device=self.device_type, requires_grad=False) Y = torch.randn(8, 8, device=self.device_type, requires_grad=False) - row_wise = [Shard(0)] X_dt = distribute_tensor(X, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise) local_equal_allgather_forward = local_map( @@ -122,7 +148,6 @@ def test_local_map_out_placements(self): # Test 2: directly return out if no argument is DTensor # matmul in DDP - replicate = [Replicate()] X = torch.randn( 4 // self.world_size, 4, device=self.device_type, requires_grad=False ) @@ -151,17 +176,15 @@ def test_local_map_in_placements(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - replicate = [Replicate()] # replicate placements on 1-d mesh - W_dt = distribute_tensor( - W, device_mesh, row_wise - ) # row-wisely sharded W tensor - X_dt = distribute_tensor(X, device_mesh, replicate) # replicate X tensor + X_dt = distribute_tensor( + X, device_mesh, row_wise + ) # row-wisely sharded X tensor + W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor # Test 1: explicitly pass `in_placements` local_mm_forward = local_map( @@ -171,7 +194,7 @@ def test_local_map_in_placements(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_forward(W_dt, X_dt) + Y_dt = local_mm_forward(X_dt, W_dt) # no communication should occur in this case self.assertEqual(comm_mode.get_total_counts(), 0) @@ -186,7 +209,7 @@ def test_local_map_in_placements(self): device_mesh=device_mesh, ) with comm_mode: - Y_dt = local_mm_forward(W_dt, X_dt) + Y_dt = local_mm_forward(X_dt, W_dt) self.assertEqual(comm_mode.get_total_counts(), 0) for placement in Y_dt.placements: @@ -194,15 +217,16 @@ def test_local_map_in_placements(self): self.assertEqual(Y_dt.full_tensor(), Y) # Test 3: `None` placements for non-Tensor input argument + # Y = X * 2.0 local_mul_forward = local_map( mul_forward, in_placements=(row_wise, None), out_placements=row_wise, device_mesh=device_mesh, ) - Y = torch.mul(W, 2.0) + Y = torch.mul(X, 2.0) with comm_mode: - Y_dt = local_mul_forward(W_dt, 2.0) + Y_dt = local_mul_forward(X_dt, 2.0) self.assertEqual(comm_mode.get_total_counts(), 0) for placement in Y_dt.placements: @@ -210,12 +234,6 @@ def test_local_map_in_placements(self): self.assertEqual(Y_dt.full_tensor(), Y) # Test 4: `None` placements for Tensor input argument - X = torch.randn(16, 8, device=self.device_type, requires_grad=False) - W = torch.randn(8, 12, device=self.device_type, requires_grad=False) - X_dt = distribute_tensor( - X, device_mesh, row_wise - ) # row-wisely sharded X tensor - W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor local_mm_forward = local_map( mm_forward, out_placements=None, @@ -265,20 +283,17 @@ def test_local_map_redistribute(self): ) comm_mode = CommDebugMode() - # Y = W @ X - W = torch.randn(12, 8, device=self.device_type, requires_grad=False) - X = torch.randn(8, 16, device=self.device_type, requires_grad=False) - Y = torch.mm(W, X) + # Y = X @ W + X = torch.randn(16, 8, device=self.device_type, requires_grad=False) + W = torch.randn(8, 12, device=self.device_type, requires_grad=False) + Y = torch.mm(X, W) - row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh - col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh - replicate = [Replicate()] - W_dt = distribute_tensor( - W, device_mesh, row_wise - ) # row-wisely sharded W tensor which will be redistributed X_dt = distribute_tensor( - X, device_mesh, col_wise - ) # col-wisely sharded X tensor which will be redistributed + X, device_mesh, row_wise + ) # row-wisely sharded X tensor which will be redistributed + W_dt = distribute_tensor( + W, device_mesh, col_wise + ) # col-wisely sharded W tensor which will be redistributed # Test 1: allow input redistribution local_mm_allreduce_forward = local_map( @@ -289,7 +304,7 @@ def test_local_map_redistribute(self): redistribute_inputs=True, ) with comm_mode: - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) # 2 for input redistribution and 1 for output self.assertEqual(comm_mode.get_total_counts(), 3) @@ -306,7 +321,7 @@ def test_local_map_redistribute(self): redistribute_inputs=False, ) with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): - Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) + Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt) if __name__ == "__main__": diff --git a/test/distributed/_tensor/experimental/test_tp_transform.py b/test/distributed/_tensor/experimental/test_tp_transform.py index 18a322710825ef..719b3a21c08966 100644 --- a/test/distributed/_tensor/experimental/test_tp_transform.py +++ b/test/distributed/_tensor/experimental/test_tp_transform.py @@ -3,7 +3,7 @@ from typing import Dict import torch -from torch.distributed._tensor.experimental.tp_transform import ( +from torch.distributed._tensor.experimental._tp_transform import ( tensor_parallel_transformation, ) from torch.distributed.tensor.parallel.style import ( diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index 0f18d54697021a..238e551ab6a4aa 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -7,14 +7,15 @@ import torch.nn.functional as F from torch import nn from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.experimental.attention import ( +from torch.distributed._tensor.experimental._attention import ( _AttentionContextParallel, _CausalBehavior, - _context_parallel_buffers, + _cp_options, _is_causal_behavior, context_parallel, + context_parallel_unshard, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing._internal.common_cuda import ( @@ -24,6 +25,7 @@ ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, parametrize, run_tests, @@ -57,11 +59,19 @@ def world_size(self) -> int: "Does not support flash nor efficient attention", ) @with_comms + @decorateIf( + unittest.skip, lambda params: params["load_balance"] and not params["is_causal"] + ) @parametrize("is_causal", [True, False]) @parametrize("compiled", [True, False]) @parametrize("backend", backends) + @parametrize("load_balance", [True, False]) def test_ring_attention_sdpa( - self, is_causal: bool, compiled: bool, backend: SDPBackend + self, + is_causal: bool, + compiled: bool, + backend: SDPBackend, + load_balance: bool, ) -> None: device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size)) dtype = torch.bfloat16 @@ -79,6 +89,8 @@ def test_ring_attention_sdpa( # TODO: Fix this after we move `wait_tensor` to use `with_effect`. return + _cp_options.enable_load_balance = load_balance + q = torch.rand( (bs, nheads, self.world_size * query_tokens, dim), device=self.device_type, @@ -108,12 +120,6 @@ def test_ring_attention_sdpa( out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) out.sum().backward() - local_out, local_dq, local_dk, local_dv = _context_parallel_buffers( - device_mesh, - buffers=(out, q.grad, k.grad, v.grad), - buffer_seq_dims=(2, 2, 2, 2), - ) - cp_q = q.clone().detach() cp_k = k.clone().detach() cp_v = v.clone().detach() @@ -154,21 +160,26 @@ def test_ring_attention_sdpa( # Due to numerical error, we need to choose different atol for different # attention kernels + cp_out, cp_dq, cp_dk, cp_dv = context_parallel_unshard( + device_mesh, + [cp_out, cp_q.grad, cp_k.grad, cp_v.grad], + [2, 2, 2, 2], + ) atol = ( 1e-08 if backend == SDPBackend.EFFICIENT_ATTENTION else 1e-3 * self.world_size ) - self.assertTrue(torch.allclose(local_out, cp_out, atol=atol)) + self.assertTrue(torch.allclose(out, cp_out, atol=atol)) atol = ( 2e-06 if backend == SDPBackend.EFFICIENT_ATTENTION else 8e-3 * self.world_size ) - self.assertTrue(torch.allclose(local_dq, cp_q.grad, atol=atol)) - self.assertTrue(torch.allclose(local_dk, cp_k.grad, atol=atol)) - self.assertTrue(torch.allclose(local_dv, cp_v.grad, atol=atol)) + self.assertTrue(torch.allclose(q.grad, cp_dq, atol=atol)) + self.assertTrue(torch.allclose(k.grad, cp_dk, atol=atol)) + self.assertTrue(torch.allclose(v.grad, cp_dv, atol=atol)) cp_q.grad = None cp_k.grad = None @@ -178,6 +189,7 @@ def test_ring_attention_sdpa( cp_v.requires_grad = False def test_is_causal_behavior(self) -> None: + _cp_options.enable_load_balance = False self.assertEqual( _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False), _CausalBehavior.NOT_IS_CAUSAL, @@ -194,6 +206,18 @@ def test_is_causal_behavior(self) -> None: behavior, ) + _cp_options.enable_load_balance = True + ranks = [ + [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL], + [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL], + ] + for rank, iters in enumerate(ranks): + for i, behavior in enumerate(iters): + self.assertEqual( + _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True), + behavior, + ) + @skip_if_lt_x_gpu(2) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention" @@ -202,6 +226,7 @@ def test_is_causal_behavior(self) -> None: @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) @parametrize("is_causal", [True, False]) def test_ring_attention_native_transformer(self, is_causal: bool) -> None: + _cp_options.enable_load_balance = is_causal device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py index 77b5d91405a73a..f712ef1cf02bad 100644 --- a/test/distributed/_tensor/test_common_rules.py +++ b/test/distributed/_tensor/test_common_rules.py @@ -3,9 +3,9 @@ import torch from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor._op_schema import OpSchema -from torch.distributed._tensor.ops._common_rules import einop_rule, pointwise_rule from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema +from torch.distributed.tensor._ops._common_rules import einop_rule, pointwise_rule from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py index 14f92115e7637b..bc94a4c859a68d 100644 --- a/test/distributed/_tensor/test_dtensor.py +++ b/test/distributed/_tensor/test_dtensor.py @@ -14,7 +14,7 @@ DTensor, init_device_mesh, ) -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed._tensor.experimental import implicit_replication from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, @@ -22,6 +22,7 @@ Shard, TensorMeta, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -778,8 +779,6 @@ def test_implicit_replication(self): local_tensor1 = torch.ones(4, 3) sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)]) - from torch.distributed._tensor.experimental import implicit_replication - with implicit_replication(): # We put the scalar tensor as the left operand so we can test out # when a non-dtensor is a the arg in the args list. @@ -816,6 +815,41 @@ def add_scalar_tensor_with_dtensor(): (numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor ) + @with_comms + def test_implicit_replication_for_foreach_ops(self): + mesh = init_device_mesh( + self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp") + ) + global_tensor1 = torch.randn(4, 2) + dtensor_2d = distribute_tensor(global_tensor1, mesh, [Shard(0), Shard(1)]) + self.assertEqual(dtensor_2d.full_tensor(), global_tensor1) + global_tensor2 = torch.randn(4) + dtensor_1d = distribute_tensor(global_tensor2, mesh["dp"], [Shard(0)]) + dtensor_list = [dtensor_2d, dtensor_1d] + + # Check without implicit replication, cross mesh error raises. + with self.assertRaisesRegex( + RuntimeError, "DTensor does not support cross-mesh operation yet!" + ): + torch._foreach_mul(dtensor_list, 2.0) + + # Check dtensor result matches tensor result. + with implicit_replication(): + torch._foreach_mul_(dtensor_list, 2.0) + self.assertEqual(dtensor_list[0].full_tensor(), global_tensor1 * 2.0) + self.assertEqual(dtensor_list[1].full_tensor(), global_tensor2 * 2.0) + + mesh_1d = DeviceMesh.from_group(mesh["tp"].get_group(), self.device_type) + dtensor_1d = distribute_tensor(global_tensor2, mesh_1d, [Shard(0)]) + dtensor_list = [dtensor_2d, dtensor_1d] + + # Check even with implicit replication, cross mesh error raises if different device mesh don't + # belong to the same root mesh. + with self.assertRaisesRegex( + RuntimeError, "DTensor does not support cross-mesh operation yet!" + ): + torch._foreach_mul_(dtensor_list, 2.0) + @with_comms def test_metadata_consistency_check(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -909,7 +943,7 @@ def test_split_tensor_1D(self) -> None: ] assert_array_equal(expected_pad_sizes, pad_sizes) - from torch.distributed._tensor._collective_utils import unpad_tensor + from torch.distributed.tensor._collective_utils import unpad_tensor unpadded_list = [ unpad_tensor(tensor, shard_placement.dim, pad_sizes[i]) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index a58ab232748c83..3f4ddfce7813f2 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -176,6 +176,25 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + def test_dtensor_dynamic(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in DTensor as inputs/outputs and run some tensor computation + def fn(x): + return ( + torch.mul(x, x) + .redistribute(device_mesh=x.device_mesh, placements=[Replicate()]) + .to_local()[0] + ) + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + torch._dynamo.mark_dynamic(x, 0) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(res, ref) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -299,6 +318,70 @@ def from_local_kwargs_fn(x): self.assertEqual(res, ref) self.assertEqual(cnt.frame_count, 2) + def test_dynamo_dtensor_from_local_dynamic_shapes(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # Case 1: all dims dynamic + def fn(x): + dt = DTensor.from_local( + x, + mesh, + [Replicate()], + run_check=False, + shape=x.shape, + stride=x.stride(), + ) + return dt.to_local() + 2 + + inp = torch.randn(4, 6, requires_grad=True) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + + # Case 2: only sizes are dynamic, strides are static + def fn(x): + dt = DTensor.from_local( + x, mesh, [Replicate()], run_check=False, shape=x.shape, stride=(1,) + ) + return dt.to_local() + 2 + + inp = torch.randn(4, requires_grad=True) + torch._dynamo.mark_dynamic(inp, 0) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + + # Case 3: both sizes and strides have a mix of dynamic and static dims + def fn(x): + dt = DTensor.from_local( + x, + mesh, + [Replicate()], + run_check=False, + shape=(x.shape[0], x.shape[1], 2), + stride=(x.stride()[0], 2, 1), + ) + return dt.to_local() + 2 + + inp = torch.randn(4, 6, 2, requires_grad=True) + torch._dynamo.mark_dynamic(inp, 0) + torch._dynamo.mark_dynamic(inp, 1) + ref = fn(inp) + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + res = torch.compile(fn, backend=cnt, fullgraph=True)(inp) + res.sum().backward() + + self.assertEqual(res, ref) + self.assertEqual(cnt.frame_count, 1) + def test_dynamo_dtensor_recompile(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index d13ec1724983d6..532bd3facae554 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -191,9 +191,6 @@ def wrapped(fn): xfail("index_reduce", "amin"), xfail("index_select"), xfail("isin"), - xfail("isinf"), - xfail("isneginf"), - xfail("isposinf"), xfail("kthvalue"), xfail("linalg.cholesky"), xfail("linalg.cholesky_ex"), @@ -307,7 +304,6 @@ def wrapped(fn): xfail("nn.functional.elu"), xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), - xfail("nn.functional.gaussian_nll_loss"), xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), @@ -353,7 +349,6 @@ def wrapped(fn): xfail("nn.functional.pdist"), xfail("nn.functional.pixel_shuffle"), xfail("nn.functional.pixel_unshuffle"), - xfail("nn.functional.poisson_nll_loss"), xfail("nn.functional.prelu"), xfail("nn.functional.relu6"), xfail("nn.functional.rrelu"), diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py index 3c366d570b37cc..889f62a8f22309 100644 --- a/test/distributed/_tensor/test_embedding_ops.py +++ b/test/distributed/_tensor/test_embedding_ops.py @@ -10,7 +10,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -167,7 +167,7 @@ def test_sharded_embedding_rowwise(self): self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22) self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10) - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial # test collectives embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type) @@ -191,7 +191,7 @@ def test_multiple_embeddings_rowwise(self): inp = torch.randint(0, 10, (4, 4), device=self.device_type) replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False) - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial # case 1: two embeddings with the same shape, thus sharing the underying _MaskPartial # and MaskBuffer, because of cache hit from sharding propagation diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py index f6fa076b5b9069..d40b0571eb6947 100644 --- a/test/distributed/_tensor/test_math_ops.py +++ b/test/distributed/_tensor/test_math_ops.py @@ -13,9 +13,9 @@ distribute_tensor, DTensor, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.ops.utils import is_tensor_partial, normalize_dim from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py index 45988e233248f0..40241917bd7c52 100644 --- a/test/distributed/_tensor/test_matrix_ops.py +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -8,13 +8,13 @@ import torch.nn.functional as F from torch.distributed._tensor import DeviceMesh, distribute_tensor from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py index 302e2675cc8995..86d5b92075eab9 100644 --- a/test/distributed/_tensor/test_op_strategy.py +++ b/test/distributed/_tensor/test_op_strategy.py @@ -4,12 +4,6 @@ import torch from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy -from torch.distributed._tensor.ops._einsum_strategy import ( - EinsumDims, - gen_einsum_strategies, -) from torch.distributed._tensor.placement_types import ( DTensorSpec, Partial, @@ -17,6 +11,12 @@ Shard, TensorMeta, ) +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._op_schema import OpSchema, OpStrategy, PlacementStrategy +from torch.distributed.tensor._ops._einsum_strategy import ( + EinsumDims, + gen_einsum_strategies, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase @@ -169,7 +169,7 @@ def test_redistribute_cost_mesh_1d(self): def test_redistribute_cost_latency(self): # test cost model on addmm op - from torch.distributed._tensor.ops._matrix_ops import addmm_strategy + from torch.distributed.tensor._ops._matrix_ops import addmm_strategy mesh = self.build_device_mesh() shard0_placement = (Shard(0),) @@ -246,7 +246,7 @@ def test_redistribute_cost_mesh_2d(self): self.assertTrue(allreduce_cost > reduce_scatter_cost) def test_mm_strategies(self): - from torch.distributed._tensor.ops._matrix_ops import mm_strategy + from torch.distributed.tensor._ops._matrix_ops import mm_strategy mesh = self.build_device_mesh() lhs_tensor = torch.randn(6, 8) @@ -292,7 +292,7 @@ def test_mm_strategies(self): self.assertFalse(output_sharding.needs_redistribute) def test_bmm_strategies(self): - from torch.distributed._tensor.ops._matrix_ops import bmm_strategy + from torch.distributed.tensor._ops._matrix_ops import bmm_strategy mesh = self.build_device_mesh() lhs_tensor = torch.randn(8, 6, 8) diff --git a/test/distributed/_tensor/test_random_ops.py b/test/distributed/_tensor/test_random_ops.py index c149075eddd989..6964b412537a23 100644 --- a/test/distributed/_tensor/test_random_ops.py +++ b/test/distributed/_tensor/test_random_ops.py @@ -5,13 +5,13 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed._tensor.api import distribute_tensor from torch.distributed._tensor.placement_types import Replicate, Shard -from torch.distributed._tensor.random import is_rng_supported_mesh, manual_seed from torch.distributed.distributed_c10d import broadcast_object_list +from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py index 5634869155052c..6a940d3c11e8cf 100644 --- a/test/distributed/_tensor/test_redistribute.py +++ b/test/distributed/_tensor/test_redistribute.py @@ -5,10 +5,10 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor._collective_utils import shard_dim_alltoall -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor._collective_utils import shard_dim_alltoall +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -207,7 +207,7 @@ def test_replicate_to_partial(self): with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"): partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec]) - from torch.distributed._tensor._redistribute import Redistribute + from torch.distributed.tensor._redistribute import Redistribute comm_mode = CommDebugMode() diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py index dde1a515b5da1e..f9153c126bc8e9 100644 --- a/test/distributed/_tensor/test_tensor_ops.py +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -3,8 +3,8 @@ import torch from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -445,7 +445,7 @@ def test_gather(self): # case 2 input sharding: input sharded, index replicated, output mask partial # only works when index has size 1 on the gather dimension and # input is sharded on the gather dimension - from torch.distributed._tensor.ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial gather_dim = 1 global_input = torch.randn(12, 8, 16) @@ -614,7 +614,7 @@ def test_dtensor_dtype_conversion(self): self.assertEqual(bf16_sharded_dtensor1.dtype, torch.bfloat16) self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16) - from torch.distributed._tensor.debug import _get_sharding_prop_cache_info + from torch.distributed.tensor.debug import _get_sharding_prop_cache_info # by this point we only have cache misses hits, misses, _, _ = _get_sharding_prop_cache_info() diff --git a/test/distributed/_tensor/test_utils.py b/test/distributed/_tensor/test_utils.py index 3bad008fb806eb..e41c990f21a583 100644 --- a/test/distributed/_tensor/test_utils.py +++ b/test/distributed/_tensor/test_utils.py @@ -4,19 +4,11 @@ import torch from torch.distributed._tensor import distribute_tensor, DTensor -from torch.distributed._tensor._utils import ( - compute_local_shape, - compute_local_shape_and_global_offset, -) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import ( - _StridedShard, - DTensorSpec, - Replicate, - Shard, - TensorMeta, -) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -32,48 +24,17 @@ class UtilTest(DTensorTestBase): def world_size(self): return 8 - @with_comms - def test_compute_local_shape_2d_uneven(self): - # mesh: 4 * 2 - mesh_tensor = torch.arange(self.world_size).reshape(4, 2) - mesh = DeviceMesh(self.device_type, mesh_tensor) - size = torch.Size([7, 7]) - rank_coordinates = mesh.get_coordinate() - - # replicate, shard - placements2 = [Replicate(), Shard(0)] - local_size2 = compute_local_shape(size, mesh, placements2) - if rank_coordinates[1] < 1: - self.assertEqual(local_size2, torch.Size([4, 7])) - else: - self.assertEqual(local_size2, torch.Size([3, 7])) - - # shard, shard - placements3 = [Shard(0), Shard(1)] - local_size3 = compute_local_shape(size, mesh, placements3) - # first dim - if rank_coordinates[0] < 3: - self.assertEqual(local_size3[0], 2) - else: - self.assertEqual(local_size3[0], 1) - # second dim - if rank_coordinates[1] < 1: - self.assertEqual(local_size3[1], 4) - else: - self.assertEqual(local_size3[1], 3) - @with_comms def test_compute_local_shape_and_global_offset_1D(self): one_d_placements = [[Shard(0)], [Replicate()]] + device_mesh = init_device_mesh(self.device_type, (self.world_size,)) for placements in one_d_placements: # When the placements is [Shard(0)], we test for three different scenarios: # 1) sharding resulting in empty shards on all or some of the ranks # 2) sharding resulting in shards of different size across different ranks # 3) sharding resulting in non-empty shards of same size across all ranks for size in range(self.world_size * 2 + 1): - mesh_tensor = torch.arange(self.world_size) - device_mesh = DeviceMesh(self.device_type, mesh_tensor) global_tensor = torch.arange(size) global_shape = global_tensor.size() @@ -101,12 +62,12 @@ def test_compute_local_shape_and_global_offset_2D(self): itertools.combinations_with_replacement(two_d_placements_options, 2) ) + # mesh: 2 * 4 + device_mesh = init_device_mesh(self.device_type, (2, 4)) for placements in two_d_placements: - for dim_0_size in (1, 2, 4, 8): - # mesh: 2 * 4 - mesh_tensor = torch.arange(self.world_size).reshape(2, 4) - device_mesh = DeviceMesh(self.device_type, mesh_tensor) - global_tensor = torch.arange(64).view(dim_0_size, -1) + for dim_0_size in range(1, 9): + nelem = 64 // dim_0_size * dim_0_size + global_tensor = torch.arange(nelem).view(dim_0_size, -1) global_shape = global_tensor.size() dtensor = distribute_tensor(global_tensor, device_mesh, placements) diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py index 8ace53d97131b2..630c7f8511d825 100644 --- a/test/distributed/_tensor/test_view_ops.py +++ b/test/distributed/_tensor/test_view_ops.py @@ -7,9 +7,15 @@ import torch import torch.distributed as dist from torch import rand, randn, Tensor -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.ops._view_ops import ( +from torch.distributed._tensor import ( + DeviceMesh, + distribute_tensor, + init_device_mesh, + Replicate, + Shard, +) +from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor._ops._view_ops import ( Broadcast, dim_maps, Flatten, @@ -19,7 +25,7 @@ Split, view_groups, ) -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -29,6 +35,10 @@ class TestViewOps(DTensorTestBase): + @property + def world_size(self) -> int: + return 6 + def test_view_groups(self): self.assertEqual( view_groups([2, 3], [3, 2]), @@ -106,8 +116,8 @@ def test_view_groups(self): view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]), ( Flatten((InputDim(2), InputDim(3))), - Singleton(), - Singleton(), + InputDim(4), + InputDim(5), Singleton(), ), ) @@ -116,7 +126,7 @@ def test_view_groups(self): ( Split(InputDim(2), (3, 4), 0), Split(InputDim(2), (3, 4), 1), - Singleton(), + InputDim(3), Flatten((InputDim(6), InputDim(7))), ), ) @@ -125,10 +135,6 @@ def test_view_groups(self): (InputDim(0), InputDim(1), InputDim(2)), ) - @property - def world_size(self) -> int: - return 6 - def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): dim_map = dim_maps[op] rules = dim_map(*args, **kwargs) @@ -429,7 +435,7 @@ def test_view_ops(self): self.dimmap_test( Tensor.view, (randn(1, 1, 42, 1, 24, 1), -1), - (Flatten((InputDim(2), InputDim(4))),), + (Flatten((InputDim(2), InputDim(input_dim=3), InputDim(4))),), ) self.dimmap_test( @@ -525,6 +531,46 @@ def test_complex_view_ops(self): ) self.assertEqual(out, out_dt.full_tensor()) + @with_comms + def test_dtensor_view_op_uneven(self): + """ + Test two uneven cases for view op: + 1) the sharded tensor dim is 1 so that only the first rank has an non-empty shard. + 2) the sharded tensor dim is uneven such that some ranks have full shards, + smaller non-empty shards, and empty shards. + """ + dim0_sizes = [1, self.world_size + 1] + for dim0_size in dim0_sizes: + p = torch.randn(dim0_size, 2, 2, 2) + mesh = init_device_mesh(self.device_type, (self.world_size,)) + dtensor = distribute_tensor(p, mesh, [Shard(0)]) + + with CommDebugMode() as comm_mode: + view = dtensor.view(dim0_size, 2, 4) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + # when no communication happens, the data pointer should be the same. + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + + view = dtensor.view(dim0_size, 4, 2) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + + view = dtensor.view(dim0_size, 8) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + + view = dtensor.view(dtensor.shape) + self.assertEqual( + view.to_local().data_ptr(), dtensor.to_local().data_ptr() + ) + self.assertEqual(len(comm_mode.get_comm_counts()), 0) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_tools/test_runtime_estimator.py b/test/distributed/_tools/test_runtime_estimator.py new file mode 100644 index 00000000000000..400903f17673f5 --- /dev/null +++ b/test/distributed/_tools/test_runtime_estimator.py @@ -0,0 +1,197 @@ +# Owner(s): ["module: unknown"] +import unittest +from dataclasses import dataclass +from typing import Any, Callable, cast, Tuple, Union + +import torch +from torch import nn, optim +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, +) + + +@dataclass +class ConvArgs: + image_size: int + num_classes: int + + +class SimpleCNN(nn.Module): + def __init__(self, conv_args: ConvArgs): + super().__init__() + image_size = conv_args.image_size + num_classes = conv_args.num_classes + self.image_size = image_size + self.conv1 = nn.Conv2d(3, 32, kernel_size=5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) + self.conv3 = nn.Conv2d(64, 128, kernel_size=3) + self.conv4 = nn.Conv2d(128, 256, kernel_size=3) + self.fc1_size = self._calculate_fc1_size() + self.fc1 = nn.Linear(self.fc1_size, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, num_classes) + + def _calculate_fc1_size(self): + size = self.image_size + size = (size - 5 + 1) // 2 # conv1 and pool + size = (size - 5 + 1) // 2 # conv2 and pool + size = size - 3 + 1 # conv3 + size = (size - 3 + 1) // 2 # conv4 and pool + return 512 * size * size + + def forward(self, x): + x = self.pool(nn.functional.relu(self.conv1(x))) + x = self.pool(nn.functional.relu(self.conv2(x))) + x = nn.functional.relu(self.conv3(x)) + x = self.pool(nn.functional.relu(self.conv4(x))) + x = x.view(-1, self.fc1_size) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class TestRuntimeEstimator(TestCase): + def _train_step( + self, + model: nn.Module, + optimizer: optim.Optimizer, + inp: torch.Tensor, + ): + out = model(inp) + loss = out.sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + + def _measure_actual_cuda_time( + self, + func: Callable, + args: Tuple[Any, ...], + ) -> float: + warmup_iters, actual_iters = 2, 5 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for _ in range(warmup_iters): + func(*args) + start_event.record() + for _ in range(actual_iters): + func(*args) + end_event.record() + torch.cuda.synchronize() + measured_time = start_event.elapsed_time(end_event) / actual_iters + return measured_time + + def _runtime_estimate( + self, + estimate_mode: str, + func: Callable, + args: Tuple[Any, ...], + ) -> float: + # Optimizer init step + func(*args) + runtime_estimator = RuntimeEstimator() + with runtime_estimator(estimate_mode_type=estimate_mode): + func(*args) + return runtime_estimator.total_runtime + + def _init_model_and_args( + self, + model_type: str, + model_args: Union[ConvArgs, ModelArgs], + bsz: int, + ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]: + dev = torch.cuda.current_device() + if model_type == "Transformer": + model_args = cast(ModelArgs, model_args) + with torch.device(dev): + model = Transformer(model_args) + optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True) + inp = torch.randint( + 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev + ) + elif model_type == "CNN": + model_args = cast(ConvArgs, model_args) + with torch.device(dev): + model = SimpleCNN(model_args) + optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True) + inp = torch.randn( + bsz, 3, model_args.image_size, model_args.image_size, device=dev + ) + else: + raise NotImplementedError("Only Transformer and CNN is supported") + return (model, optimizer, inp) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_transformer_runtime( + self, + ): + """Runs a basic GPT-2 model""" + vocab_size = 8192 + bsz, seq_len = 8, 1024 + model_args = ModelArgs( + n_layers=4, + n_heads=12, + vocab_size=vocab_size, + max_seq_len=seq_len, + dim=768, + dropout_p=0.1, + ) + + args = self._init_model_and_args("Transformer", model_args, bsz) + actual_runtime = self._measure_actual_cuda_time(self._train_step, args) + with FakeTensorMode(): + fake_args = self._init_model_and_args("Transformer", model_args, bsz) + benchmark_estimate = self._runtime_estimate( + "operator-level-benchmark", self._train_step, fake_args + ) + roofline_estimate = self._runtime_estimate( + "operator-level-cost-model", self._train_step, fake_args + ) + benchmark_accuracy = actual_runtime / benchmark_estimate + roofline_accuracy = actual_runtime / roofline_estimate + print( + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}" + f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + ) + self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_conv_model_runtime( + self, + ): + """Runs a simple CNN model""" + num_classes = 100 + bsz, img_sz = 256, 128 + model_args = ConvArgs(img_sz, num_classes) + args = self._init_model_and_args("CNN", model_args, bsz) + actual_runtime = self._measure_actual_cuda_time(self._train_step, args) + with FakeTensorMode(): + fake_args = self._init_model_and_args("CNN", model_args, bsz) + benchmark_estimate = self._runtime_estimate( + "operator-level-benchmark", self._train_step, fake_args + ) + roofline_estimate = self._runtime_estimate( + "operator-level-cost-model", self._train_step, fake_args + ) + benchmark_accuracy = actual_runtime / benchmark_estimate + roofline_accuracy = actual_runtime / roofline_estimate + print( + f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n" + f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}" + ) + self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/algorithms/quantization/test_quantization.py b/test/distributed/algorithms/quantization/test_quantization.py index 0915b103d241cd..94a1c763474501 100644 --- a/test/distributed/algorithms/quantization/test_quantization.py +++ b/test/distributed/algorithms/quantization/test_quantization.py @@ -14,7 +14,7 @@ requires_gloo, requires_nccl, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) from torch.testing._internal.common_utils import ( NO_MULTIPROCESSING_SPAWN, @@ -112,7 +112,7 @@ def test_all_gather_bfp16(self): BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_fp16(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( @@ -137,7 +137,7 @@ def test_all_to_all_fp16(self): BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16" ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_bfp16(self): store = dist.FileStore(self.file_name, self.world_size) dist.init_process_group( diff --git a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py index 887e46e2510bad..5ea55270936d7d 100644 --- a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py +++ b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: distributed"] +import contextlib import copy import torch @@ -7,6 +8,7 @@ import torch.nn as nn from torch.distributed._composable.fsdp import fully_shard from torch.distributed._tensor import DTensor, init_device_mesh +from torch.distributed._tensor.experimental import implicit_replication from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, get_optimizer_state_dict, @@ -439,8 +441,17 @@ def _get_base_model(mlp_dim: int = 2): self.assertEqual(base_osd, fsdp2_tp_full_osd) @skip_if_lt_x_gpu(4) - @with_temp_dir def test_save_with_fsdp2_tp_and_load_with_tp(self): + self.run_subtests( + {"allow_implicit_replication": [True, False]}, + self._test_save_with_fsdp2_tp_and_load_with_tp, + ) + + @skip_if_lt_x_gpu(4) + @with_temp_dir + def _test_save_with_fsdp2_tp_and_load_with_tp( + self, allow_implicit_replication: bool + ): """ Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP. """ @@ -449,6 +460,11 @@ def _get_base_model(mlp_dim: int = 2): base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim)) return base_model + cm = ( + implicit_replication() + if allow_implicit_replication + else contextlib.nullcontext() + ) tp_parallelize_plan = { "0.in_proj": ColwiseParallel(), "0.out_proj": RowwiseParallel(), @@ -457,108 +473,124 @@ def _get_base_model(mlp_dim: int = 2): "2.in_proj": ColwiseParallel(), "2.out_proj": RowwiseParallel(), } + if allow_implicit_replication: + # intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized + tp_parallelize_plan.pop("0.in_proj") + tp_parallelize_plan.pop("0.out_proj") - # init device mesh - dp_size = 2 - global_mesh_1d = init_device_mesh( - "cuda", (self.world_size,), mesh_dim_names=("tp",) - ) - global_mesh_2d = init_device_mesh( - "cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp") - ) - dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"] - - for save_full_state_dict in [True, False]: - # Save state dict with original model - base_model = _get_base_model().cuda() - base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1) - - # Save state dict with FSDP2 + TP model - fsdp2_tp_model = copy.deepcopy(base_model) - fsdp2_tp_model = parallelize_module( - fsdp2_tp_model, - device_mesh=tp_mesh, - parallelize_plan=tp_parallelize_plan, - ) - for module in fsdp2_tp_model: - fully_shard(module, mesh=dp_mesh) - fully_shard(fsdp2_tp_model, mesh=dp_mesh) - fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1) - - # one-step training to modify state dict - inp = torch.randn((2,), device=self.rank) - base_model(inp).sum().backward() - base_optim.step() - fsdp2_tp_model(inp).sum().backward() - fsdp2_tp_optim.step() - - # obtain the unsharded state dict - base_msd = get_model_state_dict( - base_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - base_osd = get_optimizer_state_dict( - base_model, - base_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - - # obtain FSDP2 + TP state dict - fsdp2_tp_msd = get_model_state_dict( - fsdp2_tp_model, - options=StateDictOptions(full_state_dict=save_full_state_dict), - ) - fsdp2_tp_osd = get_optimizer_state_dict( - fsdp2_tp_model, - fsdp2_tp_optim, - options=StateDictOptions(full_state_dict=save_full_state_dict), - ) - - fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd} - dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir) - - fsdp2_tp_full_msd = get_model_state_dict( - fsdp2_tp_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - fsdp2_tp_full_osd = get_optimizer_state_dict( - fsdp2_tp_model, - fsdp2_tp_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), - ) - - # Load state dict into model with TP applied - tp_model = _get_base_model() - tp_model = parallelize_module( - tp_model, - device_mesh=global_mesh_1d, - parallelize_plan=tp_parallelize_plan, - ) - tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1) - - tp_state_dict = { - "model": get_model_state_dict(tp_model), - "optim": get_optimizer_state_dict(tp_model, tp_optim), + with cm: + tp_parallelize_plan = { + "0.in_proj": ColwiseParallel(), + "0.out_proj": RowwiseParallel(), + "1.in_proj": ColwiseParallel(), + "1.out_proj": RowwiseParallel(), + "2.in_proj": ColwiseParallel(), + "2.out_proj": RowwiseParallel(), } - dcp.load(tp_state_dict, checkpoint_id=self.temp_dir) - tp_model.load_state_dict(tp_state_dict["model"]) - tp_optim.load_state_dict(tp_state_dict["optim"]) - tp_full_msd = get_model_state_dict( - tp_model, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), + # init device mesh + dp_size = 2 + global_mesh_1d = init_device_mesh( + "cuda", (self.world_size,), mesh_dim_names=("tp",) ) - tp_full_osd = get_optimizer_state_dict( - tp_model, - tp_optim, - options=StateDictOptions(full_state_dict=True, cpu_offload=True), + global_mesh_2d = init_device_mesh( + "cuda", + (dp_size, self.world_size // dp_size), + mesh_dim_names=("dp", "tp"), ) - - # Compare full state dict to make sure they are the same. - self.assertEqual(base_msd, tp_full_msd) - self.assertEqual(base_osd, tp_full_osd) - self.assertEqual(fsdp2_tp_full_msd, tp_full_msd) - self.assertEqual(fsdp2_tp_full_osd, tp_full_osd) + dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"] + + for save_full_state_dict in [True, False]: + # Save state dict with original model + base_model = _get_base_model().cuda() + base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1) + + # Save state dict with FSDP2 + TP model + fsdp2_tp_model = copy.deepcopy(base_model) + fsdp2_tp_model = parallelize_module( + fsdp2_tp_model, + device_mesh=tp_mesh, + parallelize_plan=tp_parallelize_plan, + ) + for module in fsdp2_tp_model: + fully_shard(module, mesh=dp_mesh) + fully_shard(fsdp2_tp_model, mesh=dp_mesh) + fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1) + + # one-step training to modify state dict + inp = torch.randn((2,), device=self.rank) + base_model(inp).sum().backward() + base_optim.step() + fsdp2_tp_model(inp).sum().backward() + fsdp2_tp_optim.step() + + # obtain the unsharded state dict + base_msd = get_model_state_dict( + base_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + base_osd = get_optimizer_state_dict( + base_model, + base_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # obtain FSDP2 + TP state dict + fsdp2_tp_msd = get_model_state_dict( + fsdp2_tp_model, + options=StateDictOptions(full_state_dict=save_full_state_dict), + ) + fsdp2_tp_osd = get_optimizer_state_dict( + fsdp2_tp_model, + fsdp2_tp_optim, + options=StateDictOptions(full_state_dict=save_full_state_dict), + ) + + fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd} + dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir) + + fsdp2_tp_full_msd = get_model_state_dict( + fsdp2_tp_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + fsdp2_tp_full_osd = get_optimizer_state_dict( + fsdp2_tp_model, + fsdp2_tp_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Load state dict into model with TP applied + tp_model = _get_base_model() + tp_model = parallelize_module( + tp_model, + device_mesh=global_mesh_1d, + parallelize_plan=tp_parallelize_plan, + ) + tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1) + + tp_state_dict = { + "model": get_model_state_dict(tp_model), + "optim": get_optimizer_state_dict(tp_model, tp_optim), + } + dcp.load(tp_state_dict, checkpoint_id=self.temp_dir) + tp_model.load_state_dict(tp_state_dict["model"]) + tp_optim.load_state_dict(tp_state_dict["optim"]) + + tp_full_msd = get_model_state_dict( + tp_model, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + tp_full_osd = get_optimizer_state_dict( + tp_model, + tp_optim, + options=StateDictOptions(full_state_dict=True, cpu_offload=True), + ) + + # Compare full state dict to make sure they are the same. + self.assertEqual(base_msd, tp_full_msd) + self.assertEqual(base_osd, tp_full_osd) + self.assertEqual(fsdp2_tp_full_msd, tp_full_msd) + self.assertEqual(fsdp2_tp_full_osd, tp_full_osd) if __name__ == "__main__": diff --git a/test/distributed/checkpoint/test_compatibility.py b/test/distributed/checkpoint/test_compatibility.py index e64e6228986db2..c7bbe1a006c552 100644 --- a/test/distributed/checkpoint/test_compatibility.py +++ b/test/distributed/checkpoint/test_compatibility.py @@ -70,6 +70,31 @@ def test_storage_meta(self) -> None: self.assertEqual(storage_meta.save_id, writer.save_id) self.assertEqual(storage_meta.load_id, reader.load_id) + @with_temp_dir + def test_with_v_2_3(self) -> None: + sd = { + "a": torch.zeros(4, 4), + "dict": { + "dict_a": {"dict_a_1": 1, "dict_a_2": 2}, + "dict_b": {"dict_b_1": 1, "dict_b_2": 2}, + }, + "list": [0, 1, 2, 3, 4, 5], + } + load_sd = { + "a": torch.ones(4, 4), + "dict": { + "dict_a": {"dict_a_1": 2, "dict_a_2": 4}, + "dict_b": {"dict_b_1": 2, "dict_b_2": 4}, + }, + "list": [10, 11, 12, 13, 14, 15], + } + + dcp._version._act_like_version = "2_3" + dcp.save(sd, checkpoint_id=self.temp_dir) + dcp._version._act_like_version = None + dcp.load(load_sd, checkpoint_id=self.temp_dir) + self.assertEqual(sd, load_sd) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_fsdp_model_state.py b/test/distributed/checkpoint/test_fsdp_model_state.py index c90e5c6f267ef5..12d013d5e776e6 100644 --- a/test/distributed/checkpoint/test_fsdp_model_state.py +++ b/test/distributed/checkpoint/test_fsdp_model_state.py @@ -34,7 +34,7 @@ def _test_fsdp_model_state(self, process_group) -> None: "model": model.state_dict(), } - dist_cp.save_state_dict( + dist_cp.save( state_dict=state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), planner=DefaultSavePlanner(), @@ -55,7 +55,7 @@ def _test_fsdp_model_state(self, process_group) -> None: "model": model_2.state_dict(), } - dist_cp.load_state_dict( + dist_cp.load( state_dict=state_dict, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), planner=DefaultLoadPlanner(), diff --git a/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py b/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py index adf2a6d3496f49..728524b4011d6b 100644 --- a/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py +++ b/test/distributed/checkpoint/test_fsdp_tp_checkpoint_conversion.py @@ -40,7 +40,7 @@ def test_fsdp_to_tp(self): fsdp_state_dict = fsdp_model.state_dict() # save fsdp_state_dict to storage - dist_cp.save_state_dict( + dist_cp.save( state_dict=fsdp_state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), ) diff --git a/test/distributed/checkpoint/test_hsdp_checkpoint.py b/test/distributed/checkpoint/test_hsdp_checkpoint.py index 02904baf6a7977..23ca7c9463be7f 100644 --- a/test/distributed/checkpoint/test_hsdp_checkpoint.py +++ b/test/distributed/checkpoint/test_hsdp_checkpoint.py @@ -94,7 +94,7 @@ def test_hsdp_checkpoint(self, is_even_sharded_model) -> None: state_dict = {"model": model.state_dict()} state_dict_to_save = deepcopy(state_dict) - dist_cp.save_state_dict( + dist_cp.save( state_dict=state_dict_to_save, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), planner=DefaultSavePlanner(), @@ -113,7 +113,7 @@ def test_hsdp_checkpoint(self, is_even_sharded_model) -> None: self.assertEqual(v1.placements, v2.placements) self.assertNotEqual(v1.to_local(), v2.to_local()) - dist_cp.load_state_dict( + dist_cp.load( state_dict=state_dict_to_save, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), planner=DefaultLoadPlanner(), diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index e299e3ee8269ed..e3dfb782ad5657 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -86,18 +86,19 @@ def _test_save_load( model, optim, copy_optim, dist_model, dist_optim = init_model_optim() # Train 10 steps. + _dist_optim = [dist_optim] if not isinstance(dist_optim, list) else dist_optim for i in range(10): + optim.zero_grad() + for d_optim in _dist_optim: + d_optim.zero_grad() + batch = torch.rand(8, 100, device="cuda") model(batch).sum().backward() - optim.step() dist_model(batch).sum().backward() - if not isinstance(dist_optim, list): - dist_optim.step() - dist_optim.zero_grad() - else: - for _dist_optim in dist_optim: - _dist_optim.zero_grad() - optim.zero_grad() + + optim.step() + for d_optim in _dist_optim: + d_optim.step() # Get the state_dict, and compare the result msd = model.state_dict() @@ -176,8 +177,8 @@ def init_model_optim(): device_mesh = init_device_mesh("cuda", (self.world_size,)) orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3, foreach=True) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) if wrapping: strategy = set(wrapping) else: @@ -204,7 +205,7 @@ def init_model_optim(): if compile_model: dist_model = torch.compile(dist_model) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3, foreach=True) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4, foreach=True) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -218,7 +219,11 @@ def test_fsdp(self) -> None: "use_composable": [True, False], "use_dtensor": [True, False], "wrapping": [(), (nn.Linear, UnitModule)], - "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], + "optimizer_class": [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + ], }, self._test_fsdp, ) @@ -248,10 +253,10 @@ def _test_fsdp2( def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) orig_optim = optimizer_class( - orig_model.parameters(), lr=1e-3, foreach=foreach + orig_model.parameters(), lr=1e-4, foreach=foreach ) copy_optim = optimizer_class( - orig_model.parameters(), lr=1e-3, foreach=foreach + orig_model.parameters(), lr=1e-4, foreach=foreach ) dist_model = FSDP2( @@ -262,7 +267,7 @@ def init_model_optim(): if compile_model: dist_model = torch.compile(dist_model) dist_optim = optimizer_class( - dist_model.parameters(), lr=1e-3, foreach=foreach + dist_model.parameters(), lr=1e-4, foreach=foreach ) return orig_model, orig_optim, copy_optim, dist_model, dist_optim @@ -284,13 +289,13 @@ def test_fsdp2(self) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None: def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) if use_composable: dist_model = replicate(copy.deepcopy(orig_model)) else: dist_model = DDP(copy.deepcopy(orig_model)) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -301,7 +306,11 @@ def test_ddp(self) -> None: self.run_subtests( { "use_composable": [True, False], - "optimizer_class": [torch.optim.Adam, torch.optim.AdamW], + "optimizer_class": [ + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.SGD, + ], }, self._test_ddp, ) @@ -320,8 +329,8 @@ def init_model_optim(): orig_model.u1.parameters(), orig_model.u2.parameters() ): param.requires_grad = False - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) dist_model = copy.deepcopy(orig_model) if use_composable: replicate(dist_model.l) @@ -336,13 +345,13 @@ def init_model_optim(): ) if optim_in_backward: _apply_optimizer_in_backward( - optimizer_class, dist_model.parameters(), {"lr": 1e-3} + optimizer_class, dist_model.parameters(), {"lr": 1e-4} ) dist_optim = [ p._in_backward_optimizers[0] for p in dist_model.parameters() ] else: - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim, test_frozen) @@ -395,10 +404,10 @@ def test_apply_optimizer_in_backward(self) -> None: def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None: def init_model_optim(): orig_model = CompositeParamModel(device=torch.device("cuda")) - orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3) - copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3) + orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) + copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) model_copy = copy.deepcopy(orig_model) - optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3) + optim_copy = optimizer_class(model_copy.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, model_copy, optim_copy self._test_save_load(init_model_optim) @@ -445,7 +454,7 @@ def _test_cpu_offload_full_state_dict( device_mesh=device_mesh, ) - dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3) + dist_optim = optimizer_class(dist_model.parameters(), lr=1e-4) mst, ost = get_state_dict( dist_model, @@ -887,10 +896,10 @@ def forward(self, input): def init_model_optim(): device_mesh = init_device_mesh("cuda", (self.world_size,)) orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) - orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) - copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3) + orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) + copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) - dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3) + dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-4) return orig_model, orig_optim, copy_optim, dist_model, dist_optim self._test_save_load(init_model_optim) @@ -958,7 +967,7 @@ def setUp(self) -> None: @skip_if_lt_x_gpu(1) def test_no_dist(self) -> None: model = CompositeParamModel(device=torch.device("cuda")) - optim = torch.optim.AdamW(model.parameters(), lr=1e-3) + optim = torch.optim.AdamW(model.parameters(), lr=1e-4) self.assertFalse(dist.is_initialized()) msd = get_model_state_dict( diff --git a/test/distributed/checkpoint/test_state_dict_utils.py b/test/distributed/checkpoint/test_state_dict_utils.py index 4bd4f9cf21f214..757cf77c067395 100644 --- a/test/distributed/checkpoint/test_state_dict_utils.py +++ b/test/distributed/checkpoint/test_state_dict_utils.py @@ -12,8 +12,7 @@ _gather_state_dict, _offload_state_dict_to_cpu, ) -from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor import DTensor, Shard from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 9f55785810ca86..98ff8f1a309749 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -6,6 +6,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import ctypes import multiprocessing import os @@ -362,6 +363,9 @@ def test_pcontext_wait(self): self.assertTrue(pc._stderr_tail.stopped()) self.assertTrue(pc._stdout_tail.stopped()) + def test_pcontext_wait_on_a_child_thread(self): + asyncio.run(asyncio.to_thread(self.test_pcontext_wait)) + def test_multiprocess_context_close(self): pc = start_processes( name="sleep", diff --git a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py index 2838d918964021..89329c380f391f 100644 --- a/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py +++ b/test/distributed/elastic/rendezvous/c10d_rendezvous_backend_test.py @@ -25,6 +25,7 @@ C10dRendezvousBackend, create_backend, ) +from torch.distributed.elastic.utils.distributed import get_free_port class TCPStoreBackendTest(TestCase, RendezvousBackendTestMixin): @@ -69,9 +70,11 @@ def setUp(self) -> None: # For testing, the default parameters used are for tcp. If a test # uses parameters for file store, we set the self._params to # self._params_filestore. + + port = get_free_port() self._params = RendezvousParameters( backend="dummy_backend", - endpoint="localhost:29300", + endpoint=f"localhost:{port}", run_id="dummy_run_id", min_nodes=1, max_nodes=1, @@ -95,7 +98,7 @@ def setUp(self) -> None: self._expected_temp_dir = tempfile.gettempdir() self._expected_endpoint_host = "localhost" - self._expected_endpoint_port = 29300 + self._expected_endpoint_port = port self._expected_store_type = TCPStore self._expected_read_timeout = timedelta(seconds=10) @@ -173,11 +176,14 @@ def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_al def test_create_backend_returns_backend_if_endpoint_port_is_not_specified( self, ) -> None: - self._params.endpoint = self._expected_endpoint_host + # patch default port and pass endpoint with no port specified + with mock.patch( + "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.DEFAULT_PORT", + self._expected_endpoint_port, + ): + self._params.endpoint = self._expected_endpoint_host - self._expected_endpoint_port = 29400 - - self._assert_create_backend_returns_backend() + self._assert_create_backend_returns_backend() def test_create_backend_returns_backend_if_endpoint_file_is_not_specified( self, @@ -206,16 +212,6 @@ def test_create_backend_returns_backend_if_read_timeout_is_not_specified( self._assert_create_backend_returns_backend() - def test_create_backend_returns_backend_with_libuv(self) -> None: - self._params.config["use_libuv"] = "true" - - self._assert_create_backend_returns_backend() - - def test_create_backend_returns_backend_without_libuv(self) -> None: - self._params.config["use_libuv"] = "false" - - self._assert_create_backend_returns_backend() - def test_create_backend_raises_error_if_store_is_unreachable(self) -> None: self._params.config["is_host"] = "false" self._params.config["read_timeout"] = "2" diff --git a/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py b/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py index 67cc97267a28ed..90ea76a48fadc7 100644 --- a/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py +++ b/test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py @@ -1597,6 +1597,23 @@ def test_create_handler_records_and_raises_exceptions(self, record_mock) -> None create_handler(self._store, self._backend, self._params) record_mock.assert_called_once() + def test_create_handler_rdzv_local_addr(self) -> None: + params = RendezvousParameters( + backend=self._backend.name, + endpoint="dummy_endpoint", + run_id="dummy_run_id", + min_nodes=1, + max_nodes=1, + join_timeout="50", + last_call_timeout="60", + close_timeout="70", + local_addr="127.0.0.2", + ) + store = HashStore() + handler = create_handler(store, self._backend, params) + rdzv_info = handler.next_rendezvous() + self.assertEqual(rdzv_info.bootstrap_store_info.master_addr, "127.0.0.2") + def _ignore_exception(exception_type: Exception, fn: Callable): try: @@ -1656,7 +1673,7 @@ def _create_handler(self, **kwargs) -> DynamicRendezvousHandler: "min_nodes": 2, "max_nodes": 2, "join_timeout": "5", - "local_addr": f"address_{len(self._handlers)}", + "local_addr": f"127.0.0.{len(self._handlers)}", } params.update(**kwargs) @@ -1714,7 +1731,7 @@ def test_redundancy_list(self) -> None: state_and_token = self._backend.get_state() state = pickle.loads(state_and_token[0]) addresses = [node.addr for node in state.redundancy_list] - self.assertListEqual(addresses, ["address_2"]) + self.assertListEqual(addresses, ["127.0.0.2"]) def test_redundancy_transition_to_wait_list_then_join_rendezvous(self) -> None: handler1 = self._create_handler( diff --git a/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py new file mode 100644 index 00000000000000..4d304bef1bc234 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_rendezvous_test.py @@ -0,0 +1,38 @@ +# Owner(s): ["oncall: r2p"] + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import pathlib +import sys +import unittest + +import torch.distributed.elastic.rendezvous as rdvz + + +BACKEND_NAME = "testbackend" +TEST_PACKAGE_PATH = "/out_of_tree_test_package/src" + + +class OutOfTreeRendezvousTest(unittest.TestCase): + def test_out_of_tree_handler_loading(self): + current_path = str(pathlib.Path(__file__).parent.resolve()) + rdvz._register_out_of_tree_handlers() + registry_dict = rdvz.rendezvous_handler_registry._registry + + # test backend should not be registered as a backend + self.assertFalse(BACKEND_NAME in registry_dict) + + # Including testbackend in python path + sys.path.append(current_path + TEST_PACKAGE_PATH) + + # Registering the out of tree handlers again + rdvz._register_out_of_tree_handlers() + + # test backend should be registered as a backend + self.assertTrue(BACKEND_NAME in registry_dict) + + # Removing testbackend from python path + sys.path.remove(current_path + TEST_PACKAGE_PATH) diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml new file mode 100644 index 00000000000000..32e177bf73fe6b --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/pyproject.toml @@ -0,0 +1,6 @@ +[project] +name = "testbackend" +version = "0.0.1" + +[project.entry-points.'torchrun.handlers'] +testbackend = 'testbackend:test_handler' \ No newline at end of file diff --git a/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py new file mode 100644 index 00000000000000..1fbf5b4c2dfb10 --- /dev/null +++ b/test/distributed/elastic/rendezvous/out_of_tree_test_package/src/testbackend/__init__.py @@ -0,0 +1,2 @@ +def test_handler(): + return "" diff --git a/test/distributed/elastic/timer/file_based_local_timer_test.py b/test/distributed/elastic/timer/file_based_local_timer_test.py index 490e4a9ce37a7c..c06f3520bac853 100644 --- a/test/distributed/elastic/timer/file_based_local_timer_test.py +++ b/test/distributed/elastic/timer/file_based_local_timer_test.py @@ -6,6 +6,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import multiprocessing as mp +import os import signal import time import unittest @@ -37,7 +38,7 @@ class FileTimerTest(TestCase): def setUp(self): super().setUp() self.max_interval = 0.01 - self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4()) + self.file_path = f"/tmp/test_file_path_{os.getpid()}_{uuid.uuid4()}" self.server = timer.FileTimerServer( self.file_path, "test", self.max_interval ) @@ -204,7 +205,7 @@ def test_send_request_without_server(self): class FileTimerServerTest(TestCase): def setUp(self): super().setUp() - self.file_path = "/tmp/test_file_path_" + str(uuid.uuid4()) + self.file_path = f"/tmp/test_file_path_{os.getpid()}_{uuid.uuid4()}" self.max_interval = 0.01 self.server = timer.FileTimerServer( self.file_path, "test", self.max_interval diff --git a/test/distributed/elastic/utils/distributed_test.py b/test/distributed/elastic/utils/distributed_test.py index ded2c7e9d458f0..c0562a2a3e775a 100644 --- a/test/distributed/elastic/utils/distributed_test.py +++ b/test/distributed/elastic/utils/distributed_test.py @@ -131,6 +131,7 @@ def test_create_store_with_libuv_support(self): wait_for_workers = False localhost = socket.gethostname() + os.environ["USE_LIBUV"] = "0" store = create_c10d_store( is_server=True, server_addr=localhost, @@ -138,10 +139,12 @@ def test_create_store_with_libuv_support(self): timeout=2, world_size=world_size, wait_for_workers=wait_for_workers, - use_libuv=False, ) self.assertFalse(store.libuvBackend) + del os.environ["USE_LIBUV"] + assert "USE_LIBUV" not in os.environ + # libuv backend is enabled by default store = create_c10d_store( is_server=True, server_addr=localhost, @@ -149,7 +152,6 @@ def test_create_store_with_libuv_support(self): timeout=2, world_size=world_size, wait_for_workers=wait_for_workers, - use_libuv=True, ) self.assertTrue(store.libuvBackend) diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index fd0c6df0cc9786..bc5f010e927bbd 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -25,6 +25,7 @@ def create_one_event( state="scheduled", collective_seq_id=0, p2p_seq_id=0, + output_dtypes="float32", ): return { "profiling_name": f"nccl:{collectcive_name}", @@ -32,6 +33,8 @@ def create_one_event( "process_group": pg_info, "input_sizes": input_sizes, "output_sizes": output_sizes, + "input_dtypes": "float32", + "output_dtypes": output_dtypes, "collective_seq_id": str(collective_seq_id), "p2p_seq_id": str(p2p_seq_id), } @@ -43,49 +46,67 @@ def test_match_one_event(self): "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) membership = {"0": {0, 1}} - self.assertEqual(match_one_event(e1, e1, membership), MatchState.FULLY_MATCHED) + self.assertEqual( + match_one_event(e1, e1, membership, "0"), MatchState.FULLY_MATCHED + ) e2 = create_one_event( "all_gather", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) self.assertEqual( - match_one_event(e1, e2, membership), MatchState.COLLECTIVE_TYPE_MISMATCH + match_one_event(e1, e2, membership, "0"), + MatchState.COLLECTIVE_TYPE_MISMATCH, ) e3 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) e4 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) - self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED) + self.assertEqual(match_one_event(e3, e4, membership, "0"), MatchState.UNDECIDED) e5 = create_one_event( "all_reduce", ("0", "default"), [[5, 4]], [[4, 4]], "scheduled", 1, 1 ) self.assertEqual( - match_one_event(e1, e5, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e1, e5, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e6 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 1, 2 ) self.assertEqual( - match_one_event(e1, e6, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e1, e6, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e7 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[5, 4]], "scheduled", 2 ) self.assertEqual( - match_one_event(e7, e7, membership), MatchState.SIZE_OR_SYNTAX_MISMATCH + match_one_event(e7, e7, membership, "0"), MatchState.SIZE_OR_SYNTAX_MISMATCH ) e9 = create_one_event( "all_reduce", ("0", "default"), [[4, 4]], [[4, 4]], "completed", 1 ) self.assertEqual( - match_one_event(e1, e9, membership), MatchState.COLLECTIVE_STATE_MISMATCH + match_one_event(e1, e9, membership, "0"), + MatchState.COLLECTIVE_STATE_MISMATCH, + ) + + e10 = create_one_event( + "all_reduce", + ("0", "default"), + [[4, 4]], + [[4, 4]], + "completed", + 1, + output_dtypes="float16", + ) + self.assertEqual( + match_one_event(e10, e9, membership, "0"), + MatchState.COLLECTIVE_DTYPE_MISMATCH, ) diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index c0a2710e9c9023..5f04bc8045dbbd 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -14,12 +14,12 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, ShardingStrategy, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index e36834617aa072..adc338dcf9a97f 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -55,7 +55,13 @@ def get_a_tensor(): return t @dataclass - class SomeDataClass: + class NonFrozenDataClass: + some_key: str + some_float: float + some_tensor: List[torch.Tensor] + + @dataclass(frozen=True) + class FrozenDataClass: some_key: str some_float: float some_tensor: List[torch.Tensor] @@ -65,7 +71,10 @@ class SomeDataClass: data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) data.insert(0, {"x", get_a_tensor(), get_a_tensor()}) data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2})) - data.append({"abc": SomeDataClass("some_key", 1.0, [get_a_tensor()])}) + data.append( + {"non_frozen_ds": NonFrozenDataClass("some_key", 1.0, [get_a_tensor()])} + ) + data.append({"frozen_ds": FrozenDataClass("some_key", 1.0, [get_a_tensor()])}) od = OrderedDict() od["k"] = "value" data.append(od) diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index a17fd7cfb263a3..67edb211b9f1ed 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -403,7 +403,7 @@ def _check_same_model_params( ) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_step(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step()`` interface.""" @@ -443,7 +443,7 @@ def test_step(self): self.assertEqual(m.bias, m_zero.bias) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_step_with_closure(self): """Check that ZeroRedundancyOptimizer properly exposes the ``step(closure)`` interface.""" @@ -663,7 +663,7 @@ def test_multiple_param_groups(self): torch.testing.assert_close(layer1.bias, layer3.bias) @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess def test_collect_shards(self): """Check the state consolidation mechanism and the state dict exposed by ZeroRedundancyOptimizer.""" @@ -1383,7 +1383,7 @@ def _test_ddp_zero_overlap( @common_distributed.skip_if_win32() @common_distributed.requires_nccl() @common_distributed.skip_if_no_gpu - @common_distributed.skip_if_rocm + @common_distributed.skip_if_rocm_multiprocess @parametrize( "use_gpu", [True], diff --git a/test/distributed/pipelining/test_backward.py b/test/distributed/pipelining/test_backward.py index b99e303ca9a61b..328eddcce50695 100644 --- a/test/distributed/pipelining/test_backward.py +++ b/test/distributed/pipelining/test_backward.py @@ -77,7 +77,7 @@ def test_stage_backward_input(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) @@ -112,7 +112,7 @@ def test_stage_backward_weight(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) @@ -160,7 +160,7 @@ def test_stage_backward_weight_multiple_iters(self): dinputs, param_groups = stage_backward_input( stage_outputs=(loss,), output_grads=None, - stage_inputs=[x], + input_values=[x], weights=mod.parameters(), ) diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 327359a5807cb5..951e77188364c3 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -29,8 +29,10 @@ ) from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] instantiate_parametrized_tests, + MI300_ARCH, parametrize, run_tests, + runOnRocmArch, TestCase, ) from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule @@ -228,6 +230,7 @@ def func(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertIn("fused_all_gather_matmul", code) self.assertNotIn("all_gather_into_tensor", code) + @runOnRocmArch(MI300_ARCH) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("gather_dim", [0, 1, 2]) @@ -324,6 +327,7 @@ def func(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: self.assertIn("fused_matmul_reduce_scatter", code) self.assertNotIn("reduce_scatter_tensor", code) + @runOnRocmArch(MI300_ARCH) @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @parametrize("A_dims", [2, 3]) @parametrize("scatter_dim", [0, 1, 2]) diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py index 0c470890241214..43662c4d6cf2d4 100644 --- a/test/distributed/tensor/parallel/test_tp_examples.py +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -15,11 +15,11 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, ) +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, loss_parallel, @@ -118,8 +118,6 @@ def _test_mlp_training_e2e(self, is_seq_parallel=False, recompute_activation=Fal output = model(inp) output.sum().backward() - from torch.distributed._tensor.debug import CommDebugMode - comm_mode = CommDebugMode() with comm_mode: output_tp = model_tp(inp) diff --git a/test/distributed/tensor/parallel/test_tp_random_state.py b/test/distributed/tensor/parallel/test_tp_random_state.py index 41ca79121d4b48..b9f73a70430d46 100644 --- a/test/distributed/tensor/parallel/test_tp_random_state.py +++ b/test/distributed/tensor/parallel/test_tp_random_state.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: distributed"] import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random from torch.distributed._tensor import init_device_mesh, Replicate from torch.distributed.tensor.parallel.api import parallelize_module from torch.distributed.tensor.parallel.style import ColwiseParallel diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py index ebd6d82a8c772e..28ff10bab099ca 100644 --- a/test/distributed/tensor/parallel/test_tp_style.py +++ b/test/distributed/tensor/parallel/test_tp_style.py @@ -12,8 +12,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.debug import CommDebugMode -from torch.distributed._tensor.placement_types import _Partial +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import parallelize_module from torch.distributed.tensor.parallel.style import ( ColwiseParallel, @@ -22,6 +21,7 @@ RowwiseParallel, SequenceParallel, ) +from torch.distributed.tensor.placement_types import _Partial from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 3e5538d57e38ae..d96abb1ca82675 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1820,6 +1820,7 @@ def test_init_process_group_optional_backend(self): def test_init_process_group_for_all_backends(self): for backend in dist.Backend.backend_list: + excepted_backend = backend # skip if the backend is not available on the system if backend == dist.Backend.UNDEFINED: continue @@ -1835,6 +1836,11 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + # Multi-threaded PG is defined as a pure python class. + # Its pg.name() does not going through Pybind, so its backend name + # is still "threaded" instead of "custom". + elif backend != "threaded": + excepted_backend = "custom" with tempfile.NamedTemporaryFile(delete=False) as f: store = dist.FileStore(f.name, self.world_size) @@ -1847,7 +1853,7 @@ def test_init_process_group_for_all_backends(self): pg = c10d._get_default_group() self.assertEqual(pg.rank(), self.rank) self.assertEqual(pg.size(), self.world_size) - self.assertEqual(pg.name(), str(backend)) + self.assertEqual(pg.name(), str(excepted_backend)) dist.destroy_process_group() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 64c51f791fddd0..78688b0e6a70cc 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -48,7 +48,7 @@ requires_nccl, requires_nccl_version, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, TEST_SKIPS, with_dist_debug_levels, with_nccl_blocking_wait, @@ -237,6 +237,8 @@ def setUp(self): self.test_nan_assert_float32.__wrapped__, self.test_nan_assert_float64.__wrapped__, self.test_nan_assert_bfloat16.__wrapped__, + self.test_nan_assert_float8_e4m3fn.__wrapped__, + self.test_nan_assert_float8_e5m2.__wrapped__, ] # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests @@ -347,25 +349,103 @@ def test_close_pg(self): not (TEST_MULTIGPU and CUDA_12_AND_ABOVE), "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA", ) - @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16]) - @skip_if_rocm + @parametrize( + "type", + [ + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], + ) + @skip_if_rocm_multiprocess def test_nan_assert(self, type): + # Expecting a device-side error when NaN is detected os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) pg = self._create_process_group_nccl(store, self.opts()) device = self.rank_to_GPU[self.rank][0] - size = (10, 10) - nan_tensor = torch.full(size, self.rank, dtype=type, device=device) + # Cover different buffer sizes + if type == torch.float64: + size = (1024,) # 1K elements + elif type == torch.float32: + size = (1024, 1024) # 1M elements + elif type == torch.float16: + size = (1024, 1024, 1024) # 1G elements + else: + size = (1,) # 1 element + + # Note: currently we cannot fill values into a FP8 tensor, thus we + # create the NaN tensor in float32 type and cast it to FP8 + if type == torch.float8_e4m3fn or type == torch.float8_e5m2: + init_type = torch.float32 + else: + init_type = type + + nan_tensor = torch.zeros(*size, dtype=init_type, device=device) # randomly pick an nan element - i = random.randint(0, nan_tensor.size(0) - 1) - j = random.randint(0, nan_tensor.size(1) - 1) - nan_tensor[i, j] = float("nan") + index = tuple([random.randrange(size[i]) for i in range(len(size))]) + nan_tensor[index] = float("nan") + if init_type != type: + # Now cast to the targeted dtype + nan_tensor = nan_tensor.to(type) + + output = torch.empty(self.world_size, *size, dtype=type, device=device) with self.assertRaises(RuntimeError): - pg.allreduce(nan_tensor) + # Note: using all-gather here bc FP8 types do not support reduce ops + # at the moment + pg._allgather_base(output, nan_tensor) dist.destroy_process_group() # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_rank_filter(self): + # Putting NaN at recv buffer, program should not fail as NaN checker + # should not check on receive buffer + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + if self.rank != 0: + # Putting NaN at recv buffer + t[1, 1] = float("nan") + # Against broadcast + c10d.broadcast(t, 0) + # Against P2P + if self.rank == 0: + c10d.send(t, 1) + elif self.rank == 1: + c10d.recv(t, 0) + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_nan_check(self): + # Not expecting an error, NaN check should not make legit code fail + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank + t = torch.ones(3, 4, dtype=torch.bfloat16, device=device) + c10d.broadcast(x, src=0) + c10d.all_reduce(t) + c10d.barrier() + c10d.destroy_process_group() + # reset env + os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): @@ -1644,7 +1724,7 @@ def test_grad_layout_1devicemodule_1replicaperprocess(self): @requires_nccl() @skip_if_lt_x_gpu(4) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_grad_layout_2devicemodule(self): int_devices = gpus_for_rank(self.world_size)[self.rank][:2] dev0 = torch.device("cuda:" + str(int_devices[0])) @@ -2416,7 +2496,7 @@ def _run_all_reduce(self, pg): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_but_pass_in_sandcastle("Test does not pass when run locally") def test_nccl_errors_nonblocking(self): # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test @@ -2478,7 +2558,7 @@ def _test_nccl_errors_blocking(self, func): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_clean_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(0)) @@ -2486,7 +2566,7 @@ def test_nccl_errors_blocking_clean_exit(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_nonzero_exit(self): self._test_nccl_errors_blocking(lambda: sys.exit(1)) @@ -2494,7 +2574,7 @@ def test_nccl_errors_blocking_nonzero_exit(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess @skip_but_pass_in_sandcastle( "Frequently times out see https://github.com/pytorch/pytorch/issues/58920" ) @@ -2505,7 +2585,7 @@ def test_nccl_errors_blocking_abort(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_sigkill(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL)) @@ -2513,7 +2593,7 @@ def test_nccl_errors_blocking_sigkill(self): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(3) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_blocking_sigterm(self): self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM)) @@ -2729,7 +2809,7 @@ def test_all_reduce_coalesced_manager_nccl(self): @requires_nccl() @skip_if_lt_x_gpu(2) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_intra_node_comm_all_reduce(self): from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter from torch.testing._internal.common_cuda import SM80OrLater @@ -3552,7 +3632,7 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] - self.assertEqual(ver, "2.3") + self.assertEqual(ver, "2.4") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] @@ -4264,7 +4344,7 @@ def _check_return_codes(self, elapsed_time): @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @skip_if_lt_x_gpu(2) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_nccl_errors_dump(self): os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000" diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 09db4ac06dab35..7cbbe1a9e71450 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -5,12 +5,6 @@ import torch import torch.distributed._functional_collectives as funcol from torch.distributed._tensor import DTensor -from torch.distributed._tensor._collective_utils import ( - mesh_broadcast, - mesh_scatter, - unpad_tensor, -) -from torch.distributed._tensor.placement_types import _Partial, Shard from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( _get_default_group, @@ -22,6 +16,12 @@ is_nccl_available, ProcessGroup, ) +from torch.distributed.tensor._collective_utils import ( + mesh_broadcast, + mesh_scatter, + unpad_tensor, +) +from torch.distributed.tensor.placement_types import _Partial, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -136,6 +136,10 @@ def test_get_local_rank(self): self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp")) self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp")) + # Verify flattened mesh local rank correctness. + flattened_mesh = mesh_2d["dp", "tp"]._flatten() + self.assertEqual(flattened_mesh.get_local_rank(), self.rank) + @with_comms def test_device_mesh_2d(self): mesh_tensor = torch.arange(4).reshape(2, 2) @@ -232,7 +236,8 @@ def test_set_mesh_dim_group_options(self): mesh_tensor = torch.arange(4).reshape(2, 2) mesh = DeviceMesh(device_type, mesh_tensor) - self.assertEqual(mesh.get_group(1)._get_backend_name(), "fake") + # Fake pg only have BackendType as BackendType::CUSTOM. + self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom") class DeviceMeshTestNDim(DTensorTestBase): @@ -692,6 +697,17 @@ def test_get_mesh_dim_by_name(self): self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0) self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1) + @with_comms + def test_get_all_submeshes(self): + mesh_2d = init_device_mesh( + self.device_type, (2, 4), mesh_dim_names=("replicate", "shard") + ) + all_submeshes = _mesh_resources._get_all_submeshes(mesh_2d, "replicate") + self.assertEqual(len(all_submeshes), 4) + self.assertEqual( + all(submesh.mesh.numel() == 2 for submesh in all_submeshes), True + ) + class DeviceMeshCollectiveTest(DTensorTestBase): @property diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index afe299106b65f8..d6357678c94f97 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -127,6 +127,24 @@ def get_mutating_model( return m, inputs, outputs +class ForcedGetAttrMod(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.linear = torch.nn.Linear(1, 1) + self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device) + self.counter = 0 + + def forward(self, x): + self.counter += 1 + return x * self.linear(x) * self.forced_linear.weight + + +def get_forced_getattr_module(device): + mod = ForcedGetAttrMod(device).to(device=device) + x = torch.randn(1, 1, device=device) + return mod, x, mod(x) + + class ToyInnerModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -423,7 +441,6 @@ def forward(self, x, y): opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) - @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/130534" @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) def test_unbacked_symbol_splitting_no_binding(self): class Model(nn.Module): @@ -616,21 +633,46 @@ def test_fsdp_aot_eager(self): def test_fsdp_setattr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) - prof = torch._dynamo.utils.CompileProfiler() - fsdp_m = torch.compile(fsdp_m, backend=prof, fullgraph=False) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + self.assertEqual(len(counters["graph_break"]), 1) + first_graph_break = list(counters["graph_break"].keys())[0] # noqa: RUF015 + self.assertTrue("setattr" not in first_graph_break) + + @config.patch(inline_inbuilt_nn_modules=False) + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_unspecialized_forced_getattr_no_inline(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() + m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_unspecialized_forced_getattr_inline(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + from torch._dynamo.utils import counters + + counters.clear() + m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - FileCheck().check("Torchdynamo Profiler Report").check( - "Graph Breaks" - ).check_not( - "setattr(FSDPManagedNNModuleVariable(MutatingModel), state, ...)" - ).check_not( - "setattr(FSDPManagedNNModuleVariable(FullyShardedDataParallel), _is_root, ...)" - ).run( - prof.report() - ) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) @@ -913,6 +955,113 @@ def f(x, y): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_dim_mismatch(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(x, y): + zx = x.shape + zy = y.shape + return x.sum() + y.sum() + + if self.rank == 0: + dataloader = [[4, 2]] + else: + dataloader = [[3]] + + for data in dataloader: + f( + torch.randn(data, device=self.rank), + torch.randn(data, device=self.rank), + ) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_missing_source(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(rank, xs): + return xs[rank].sum() + + xs = [] + for _ in range(self.world_size): + xs.append(torch.randn(10, device=self.rank)) + + f(self.rank, xs) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_scalar_missing_source(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(rank, xs): + return torch.tensor(xs[rank], device=self.rank) + + xs = [] + for i in range(self.world_size): + xs.append(10 + i) + + f(self.rank, xs) + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + @config.patch(enable_compiler_collectives=True) + def test_compiler_collectives_type_mismatch(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(x): + if isinstance(x, int): + return torch.tensor(x, device=self.rank) + else: + return x.sum() + + if self.rank == 0: + x = torch.randn(10, device=self.rank) + else: + x = 12 + f(x) + + # This deadlocks, I guess we don't support this + """ + if self.rank == 0: + x = torch.randn(12, device=self.rank) + else: + x = 10 + f(x) + """ + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index 2876bdd275a54b..c4972a46405862 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -384,20 +384,23 @@ def test_all_reduce(self): self.assertIsNone(x.grad) -class TestMakeFx(MultiThreadedTestCase): - @property - def world_size(self): - return 2 - +class TestMakeFx(TestCase): def setUp(self): - super().setUp() - self._spawn_threads() + # make_fx is not thread-safe due to patching nd mutating global states + # so create a fake_pg. + self.rank = 0 + self.world_size = 2 + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) def tearDown(self): super().tearDown() - # race condition with threads causes is_fx_tracing flag to be set incorrectly. - torch.fx._symbolic_trace._is_fx_tracing_flag = False self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing()) def test_all_reduce_tracing(self): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index af59f6188581d8..a19183425d3907 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1013,22 +1013,17 @@ def func(inp): return ar input = torch.ones(4, 4, device="cuda", requires_grad=True) - # TODO implement backwards - with self.assertRaisesRegex( - RuntimeError, - "element 0 of tensors does not require grad and does not have a grad_fn", - ): - compiled = torch.compile( - func, backend="aot_eager" - ) # inductor bug with single-op allreduce graph - out = compiled(input) - out.sum().backward() - - correct_input = input.clone().detach().requires_grad_() - correct = func(correct_input) - correct.sum().backward() - self.assertTrue(same(out, correct)) - self.assertTrue(same(input.grad, correct_input.grad)) + compiled = torch.compile( + func, backend="aot_eager" + ) # inductor bug with single-op allreduce graph + out = compiled(input) + out.sum().backward() + + correct_input = input.clone().detach().requires_grad_() + correct = func(correct_input) + correct.sum().backward() + self.assertTrue(same(out, correct)) + self.assertTrue(same(input.grad, correct_input.grad)) def test_meta(self): x = torch.rand((2, 3, 4), device="meta") diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index c27ad2f10f9ed0..ec6aa7f903bf7b 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -28,7 +28,9 @@ def requires_cuda_p2p_access(): cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) + and torch.cuda.device_count() >= 2 ) num_devices = torch.cuda.device_count() for i in range(num_devices - 1): diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 843f11b51c990f..c600ca88fd2a84 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1,9 +1,11 @@ # Owner(s): ["module: dynamo"] +import contextlib import copy import functools import math import unittest # noqa: F811 from importlib import import_module +from typing import Set import torch import torch._dynamo.config @@ -83,6 +85,14 @@ def match_rng_op(node, op): return gm +def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]): + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph + return_node = list(graph.nodes)[-1] + assert return_node.target == "output" + for x in return_node.args[0]: + fwd_outputs.add(str(x)) + + class _InvalidContext: def __init__(self) -> None: pass @@ -126,18 +136,35 @@ def _custom_policy(ctx, func, *args, **kwargs): class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase): - def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): + def _validate( + self, + fn, + backend, + *args, + skip_check=False, + fullgraph=True, + compiled_autograd=False, + ): cloned_args = [] for arg in args: cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad)) + cloned_fn = copy.deepcopy(fn) + torch.manual_seed(0) expected = fn(*args) expected.sum().backward() torch.manual_seed(0) - result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args) - result.sum().backward() + compiled_fn = torch.compile(cloned_fn, fullgraph=fullgraph, backend=backend) + ctx = contextlib.nullcontext() + if compiled_autograd: + ctx = torch._dynamo.compiled_autograd.enable( + lambda gm: torch.compile(gm, fullgraph=fullgraph, backend=backend) + ) + with ctx: + result = compiled_fn(*cloned_args) + result.sum().backward() if not skip_check: self.assertEqual( @@ -442,6 +469,89 @@ def fn(x): # rand decomps do not have have numerical results as eager self._validate(fn, backend, x, skip_check=True) + @torch._functorch.config.patch(recompute_views=True) + @torch._inductor.config.patch(fx_graph_cache=False) + def test_tags_must_save_tensor_that_has_backward_hook(self): + def my_post_forward_hook(submod, args, output): + output.register_hook(my_backward_hook) + return output + + def my_backward_hook(grad): + return grad + + class MySubmod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.matmul(x, x) + z = y * y + return z + + class MyMod(torch.nn.Module): + def __init__(self): + super().__init__() + self.submod = MySubmod() + self.norm = torch.nn.LayerNorm(4) + + def forward(self, x): + out = torch.utils.checkpoint.checkpoint( + self.submod, x, use_reentrant=False + ) + norm_out = self.norm(out) + return norm_out + + def _factory_fn(): + mod = MyMod() + x = torch.ones(4, 4, dtype=torch.float32, requires_grad=True) + backend = "inductor" + return mod, x, backend + + mod_no_hook, x, backend = _factory_fn() + mod_no_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_no_hook_fwd_outputs + ) + ): + self._validate( + mod_no_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + mod_with_hook, x, backend = _factory_fn() + mod_with_hook.submod.register_forward_hook(my_post_forward_hook) + mod_with_hook_fwd_outputs = set() + + with torch._inductor.config.patch( + post_grad_custom_pre_pass=functools.partial( + collect_fwd_graph_outputs, fwd_outputs=mod_with_hook_fwd_outputs + ) + ): + self._validate( + mod_with_hook, backend, x, fullgraph=True, compiled_autograd=True + ) + + # If `z` has a backward hook, result of `z = y * y` should also be saved in addition to the usual saved tensors. + mod_no_hook_fwd_outputs_no_primal = { + x for x in mod_no_hook_fwd_outputs if not x.startswith("primals_") + } + mod_with_hook_fwd_outputs_no_primal = { + x for x in mod_with_hook_fwd_outputs if not x.startswith("primals_") + } + additional_saved_tensors = ( + mod_with_hook_fwd_outputs_no_primal - mod_no_hook_fwd_outputs_no_primal + ) + expected_additional_saved_tensors = {"mul"} + self.assertEqual( + additional_saved_tensors, + expected_additional_saved_tensors, + f""" +Expected additional saved tensors: {expected_additional_saved_tensors} but got: {additional_saved_tensors}. +Non-primal fwd outputs from model w/ backward hook: {mod_with_hook_fwd_outputs_no_primal}. +Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no_primal}.""", + ) + @requires_cuda def test_fallback(self): def gn(x, y): diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 4785a7f96cd041..629e21e0daf948 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -439,6 +439,29 @@ def f(x): self.assertEqual(result, Foo.apply(x)) self.assertEqual(cnt.frame_count, 1) + def test_fwd_no_grad(self): + # autograd.Function.forward should be traced and called under no_grad mode. + # torch.exp with out=... arguments don't support automatic differentiation, + # so can't be traced/called under grad mode (throwing RuntimeError), + # therefore this unit test ensures fwd is under no_grad mode. + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs): + torch.exp(inputs, out=inputs) + return inputs + + @staticmethod + def backward(ctx, grad_output): + return None + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + return Foo.apply(x) + + x1 = torch.randn(2, 3, requires_grad=True) + x2 = x1.clone() + self.assertEqual(f(x1), Foo.apply(x2)) + def test_amp_custom_fwd_bwd(self): torch._dynamo.utils.counters.clear() cnt = torch._dynamo.testing.CompileCounter() @@ -542,14 +565,18 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None fwd_body_0 = self.fwd_body_0 bwd_body_0 = self.bwd_body_0 - autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True]); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None + autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None return (autograd_function_apply,) class fwd_body_0(torch.nn.Module): def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + mul: "f32[]" = l_weird_b * l_weird_c clone: "f32[]" = x.clone(); x = None mul_1: "f32[]" = mul * clone; mul = clone = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return (mul_1, [l_weird_b, l_weird_c]) class bwd_body_0(torch.nn.Module): @@ -699,20 +726,6 @@ def forward(self, x): # In the future, we should make the Dynamo test suite actually # run on test_autograd.py (it's disabled right now) and delete these. def test_smoke_from_test_autograd(self): - class Func(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - out0 = x.clone() - out1 = x.clone() - ctx.mark_non_differentiable(out1) - ctx._materialize_non_diff_grads = False - return out0, out1 - - @staticmethod - def backward(ctx, g0, g1): - assert g1 is None - return g0 - def mult1(x): return x.prod(dim=-1).prod(dim=-1) @@ -838,10 +851,6 @@ def backward(ctx, grad): return grad, None def test(): - a = torch.tensor(1.0, requires_grad=True) - out = Func.apply(a)[0] - out.backward() - x = torch.ones(2, 4, 4).requires_grad_() mult2(x) @@ -1076,6 +1085,119 @@ def foo(x): foo(torch.randn(2, requires_grad=True)) self.assertEqual(cnts.frame_count, 1) + def test_mark_non_differentiable(self): + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + from torch.autograd import Function + + class MyFunction(Function): + @staticmethod + def forward(ctx, x, y): + out1 = x.sin() + out2 = y * 2 + ctx.mark_non_differentiable(out2) + return out1, out2 + + @staticmethod + def backward(ctx, grad1, grad2): + return grad1.cos(), grad2 * 0.0 + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + return MyFunction.apply(x, y) + + x = torch.tensor(10.0, requires_grad=True) + y = torch.tensor(20.0, requires_grad=True) + ref1, ref2 = MyFunction.apply(x, y) + res1, res2 = fn(x, y) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + # Ensure out1 requires gradients, out2 does not. + self.assertTrue(ref1.requires_grad) + self.assertTrue(res1.requires_grad) + self.assertFalse(ref2.requires_grad) + self.assertFalse(res2.requires_grad) + res1.sum().backward() + + # check Dynamo captured graph is correct! + actual_graph = torch._dynamo.testing.normalize_gm( + cnt.graphs[0].print_readable(print_output=False) + ) + self.assertExpectedInline( + actual_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[]", L_y_: "f32[]"): + l_x_ = L_x_ + l_y_ = L_y_ + + function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None + fwd_body_0 = self.fwd_body_0 + bwd_body_0 = self.bwd_body_0 + autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]); fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None + getitem: "f32[]" = autograd_function_apply[0] + getitem_1: "f32[]" = autograd_function_apply[1]; autograd_function_apply = None + return (getitem, getitem_1) + + class fwd_body_0(torch.nn.Module): + def forward(self, ctx, x: "f32[]", y: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + + out1: "f32[]" = x.sin(); x = None + + out2: "f32[]" = y * 2; y = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None + return ((out1, out2), []) + + class bwd_body_0(torch.nn.Module): + def forward(self, ctx, grad1: "f32[]", grad2: "f32[]"): + _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None + + cos: "f32[]" = grad1.cos(); grad1 = None + mul: "f32[]" = grad2 * 0.0; grad2 = None + + _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None + return (cos, mul) +""", + ) + + def test_mark_multi_output_non_differentiable(self): + from torch.autograd import Function + + class MyFunction(Function): + @staticmethod + def forward(ctx, x, y, z): + out1 = x.sin() + out2 = y * 2 + out3 = z + 3 + ctx.mark_non_differentiable(out2, out3) + return out1, out2, out3 + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + return grad1.cos(), grad2, grad3 + + @torch.compile(backend="aot_eager", fullgraph=True) + def fn(x, y, z): + return MyFunction.apply(x, y, z) + + x = torch.tensor(10.0, requires_grad=True) + y = torch.tensor(20.0, requires_grad=True) + z = torch.tensor(30.0, requires_grad=True) + ref1, ref2, ref3 = MyFunction.apply(x, y, z) + res1, res2, res3 = fn(x, y, z) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + self.assertEqual(ref3, res3) + # Ensure out1 requires gradients, out2 does not. + self.assertTrue(ref1.requires_grad) + self.assertTrue(res1.requires_grad) + self.assertFalse(ref2.requires_grad) + self.assertFalse(res2.requires_grad) + self.assertFalse(ref3.requires_grad) + self.assertFalse(res3.requires_grad) + res1.sum().backward() + def test_default_values(self): from torch.autograd import Function diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index d2aacd15f5e3d0..bf386bbf164926 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -1,8 +1,11 @@ # Owner(s): ["module: dynamo"] +import sys import unittest +from unittest.mock import MagicMock, patch import torch import torch._dynamo +import torch._dynamo.backends import torch._dynamo.test_case from torch._dynamo.backends.debugging import ExplainWithBackend from torch._dynamo.backends.onnxrt import has_onnxruntime @@ -294,6 +297,55 @@ def f(x): opt_f(torch.randn(3, 3)) self.assertTrue(backend_run) + def test_lookup_custom_backend(self): + from torch._dynamo import list_backends + + backends_group = "torch_dynamo_backends" + name = "mycustombackend" + + mock_3_9 = MagicMock() + mock_3_9.load.return_value = lambda: "mocked 3.9" + mock_3_9.name = name + + mock_3_10 = MagicMock() + mock_3_10.load.return_value = lambda: "mocked 3.10" + + def mock_eps(group=None): + if sys.version_info < (3, 10): + return {backends_group: [mock_3_9]} + else: + assert group == backends_group, group + mock_group = MagicMock() + mock_group.names = [name] + mock_group[name] = mock_3_10 + # mock_group[name].load.return_value = lambda: "mocked 3.10" + return mock_group + + with patch("importlib.metadata.entry_points", mock_eps): + from torch._dynamo.backends import registry + + registry._lazy_import.cache_clear() + registry._discover_entrypoint_backends.cache_clear() + + backends = list_backends() + assert name in backends, (name, backends) + + def test_backend_recompilation(self): + def fn(x): + return x + x + + input = torch.tensor(2.0) + + opt_fn = torch.compile( + fn, backend="inductor", options={"_raise_error_for_testing": False} + ) + opt_fn(input) + with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): + opt_fn = torch.compile( + fn, backend="inductor", options={"_raise_error_for_testing": True} + ) + opt_fn(input) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_comptime.py b/test/dynamo/test_comptime.py index 61a3a088013ed3..28e8f15c737eb2 100644 --- a/test/dynamo/test_comptime.py +++ b/test/dynamo/test_comptime.py @@ -160,7 +160,7 @@ def f(x): self.assertExpectedInline( FILE.getvalue(), """\ -- TensorVariable() +- FakeTensor(..., size=(2,)) """, ) @@ -186,8 +186,8 @@ def _(ctx): self.assertExpectedInline( FILE.getvalue(), """\ -x = TensorVariable() -y = TensorVariable() +x = FakeTensor(..., size=(2,)) +y = FakeTensor(..., size=(2,)) """, ) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 1b654e7228aa89..239174df715874 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1533,6 +1533,117 @@ def fn(x): self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad) self.assertEqual(cnts.frame_count, 2) + def test_sdpa_kernel_ctx_manager1(self): + modified_backend_state = [torch.nn.attention.SDPBackend.MATH] + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state + ) + + def f(x): + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [torch.nn.attention.SDPBackend.MATH] + ): + output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + return output + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16)) + + def test_sdpa_kernel_ctx_manager2(self): + original_backend_state = set(torch.nn.attention._cur_sdpa_kernel_backends()) + modified_backend_state = [torch.nn.attention.SDPBackend.MATH] + + @torch._dynamo.allow_in_graph + def check_backend_state_is_original(): + self.assertEqual( + set(torch.nn.attention._cur_sdpa_kernel_backends()), + original_backend_state, + ) + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + torch.nn.attention._cur_sdpa_kernel_backends(), modified_backend_state + ) + + def g(x): + torch._dynamo.graph_break() + output = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + return output + + def f(x): + check_backend_state_is_original() + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [torch.nn.attention.SDPBackend.MATH] + ): + output1 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + # graph break + output2 = g(x) + + output3 = torch.nn.functional.scaled_dot_product_attention(x, x, x).to( + torch.float32 + ) + check_backend_state_is_modified() + + check_backend_state_is_original() + + return output1 + output2 + output3 + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + opt_f(torch.randn(2, 2, 2, 2).to(dtype=torch.float16)) + self.assertEqual(cnts.frame_count, 3) + + # test sdpa_kernel graph break with 2 arguments + def test_sdpa_kernel_ctx_manager3(self): + modified_backend_state = { + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + } + + @torch._dynamo.allow_in_graph + def check_backend_state_is_modified(): + self.assertEqual( + set(torch.nn.attention._cur_sdpa_kernel_backends()), + modified_backend_state, + ) + + def f(x): + with torch.nn.attention.sdpa_kernel( + # pyre-fixme[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. + [ + torch.nn.attention.SDPBackend.MATH, + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + ] + ): + # FLASH_ATTENTION may not be supported, but we're not actually + # doing any sdpa + x = x + 1 + torch._dynamo.graph_break() + check_backend_state_is_modified() + x = x + 1 + + return x + + opt_f = torch.compile(f, backend="eager") + opt_f(torch.randn(2, 2)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 463634e3216358..0472702fadca61 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -184,20 +184,6 @@ def hook(module, args): all(node.target is not torch.sigmoid for node in gm1.graph.nodes) ) - def test_disable_no_recompile(self): - def gn(x): - return torch.cos(x) - - @torch.compile(backend="eager") - def fn(x): - x = torch.sin(x) - x = torch._dynamo.disable(gn, recursive=True)(x) - return torch.sin(x) - - with torch._dynamo.config.patch(error_on_recompile=True): - for _ in range(5): - fn(torch.randn(4)) - def test_allow_in_graph(self): cnts = torch._dynamo.testing.CompileCounter() @@ -611,6 +597,33 @@ def fn(x, y): self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_mark_static_nn_module(self): + @torch._dynamo.mark_static + class Mock(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + cnts = torch._dynamo.testing.CompileCounter() + mod1 = Mock(10) + mod2 = Mock(20) + mod3 = Mock(30) + opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True) + opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True) + opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True) + + x = torch.randn(4, 4) + opt_mod1(x) + opt_mod2(x) + opt_mod3(x) + + # Must be 3 compilations. If not marked static there would be 2, because self.c would be converted to symints. + self.assertEqual(cnts.frame_count, 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index c729cb1d3f1b2c..0f137776b11a71 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -10,7 +10,12 @@ from torch._dynamo.comptime import comptime from torch._dynamo.exc import Unsupported from torch.testing._internal.common_device_type import skipIf -from torch.testing._internal.common_utils import IS_FBCODE, munge_exc, TEST_Z3 +from torch.testing._internal.common_utils import ( + IS_FBCODE, + munge_exc, + skipIfWindows, + TEST_Z3, +) from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test @@ -102,7 +107,7 @@ def f(ctx): Traceback (most recent call last): File "test_exc.py", line N, in f raise NotImplementedError -torch._dynamo.exc.InternalTorchDynamoError: +torch._dynamo.exc.InternalTorchDynamoError: NotImplementedError: from user code: File "test_exc.py", line N, in fn001 @@ -200,6 +205,10 @@ def fn001(x): translation_validation=True, translation_validation_no_bisect=True, ) + @skipIfWindows( + msg='AssertionError: "tran[551 chars]s1 s2 s3) s0)\n ==> (<= (+ s1 s2) (+ s0 (* -1[511 chars][0])' # noqa: PLR0133 + != 'tran[551 chars]s1 s2) (+ s0 (* -1 s3)))\n ==> (<= (+ s1 s2) [483 chars][0])"' + ) def test_trigger_on_error(self): from torch.fx.experimental.validator import ValidationException diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index cd95d98abf30ae..d6613d84560af0 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -4,6 +4,7 @@ import torch._dynamo.config import torch._dynamo.test_case import torch._functorch.config +import torch.nn import torch.utils.checkpoint @@ -267,6 +268,29 @@ def forward(self, x): x = torch.ones(4) self.assertEqual(mod(x), opt_mod(x)) + def test_attribute_error_from_getattr(self): + class Mock: + def __init__(self): + self.a = 5 + + def __getattr__(self, name): + if name != "a": + raise AttributeError("missing") + return self.__dict__["a"] + + mock = Mock() + + def fn(x): + if hasattr(mock, "b"): + return torch.cos(x) + return torch.sin(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_stop_iteration(self): def zip_longest(*iterables, fillvalue=None): # Get the iterators for each iterable @@ -294,6 +318,21 @@ def fn(x, y): res = opt_fn(x, y) self.assertEqual(ref, res) + def test_nn_reraise(self): + class M(torch.nn.Module): + def forward(self, x): + raise ValueError("woof") + return x + 2 + + m = M() + m.register_forward_pre_hook(lambda m, go: None) + + torch._dynamo.utils.clear_compilation_metrics() + opt_call = torch.compile(lambda x: m(x), backend="eager") + self.assertRaises(ValueError, lambda: opt_call(torch.randn(3))) + metrics = torch._dynamo.utils.get_compilation_metrics() + self.assertEqual(metrics[0].fail_reason, "Observed exception") + def test_key_error(self): def fn(x, d): try: diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index e0ee9fab6d1375..03d40e377335b4 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3048,23 +3048,6 @@ def forward(self, *args): res = gm(input_tensor, input_tensor2) self.assertTrue(torch._dynamo.utils.same(ref, res)) - def test_export_mark_dynamic_conflict_dynamic_dim(self): - y = torch.randn([3, 3, 3]) - - def my_dyn_fn(x): - if x.shape[0] > 3: - return x.sin() - return x.cos() - - torch._dynamo.mark_dynamic(y, 0) - with self.assertRaisesRegex( - RuntimeError, - "Constraints violated", - ): - torch._dynamo.export( - my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},) - )(y) - def test_export_dynamic_dim_cleanup(self): y = torch.randn([3, 3, 3]) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 67ab47dfac9093..d6879488c4fd24 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -28,6 +28,7 @@ from torch._dynamo.variables import ConstantVariable from torch._dynamo.variables.lists import RangeVariable from torch.nn import functional as F +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( disable_translation_validation_if_dynamic_shapes, instantiate_parametrized_tests, @@ -156,6 +157,64 @@ def test_is_not_null(a, b): if a is not None and b is not None: return a + b + def test_foreach_lerp_(self): + def fn(x, y, s): + return torch._foreach_lerp_(x, y, s) + + cnt = torch._dynamo.testing.CompileCounter() + + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + expected = fn( + [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], + [torch.ones(2, 2), torch.ones(2, 2)], + torch.tensor(0.5), + ) + + actual = fn_opt( + [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14], + [torch.ones(2, 2), torch.ones(2, 2)], + torch.tensor(0.5), + ) + self.assertTrue(same(expected, actual)) + + def test_broadcast_foreach_pow(self): + from torch._dynamo.utils import same + + def fn(x, y): + return torch._foreach_pow(x, y) + + cnt = torch._dynamo.testing.CompileCounter() + + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)]) + + actual = fn_opt(*inps) + expected = fn(*inps) + self.assertTrue(same(actual, expected)) + self.assertTrue(cnt.frame_count, 1) + + def test_addcmul_(self): + from copy import deepcopy + + from torch._dynamo.utils import same + + def fn(x, y, z, s): + return x.addcmul_(y, z, value=s) + + cnt = torch._dynamo.testing.CompileCounter() + fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn) + inps = ( + torch.ones(2, 2), + torch.ones(2, 2) + 1, + torch.rand(2, 2), + torch.tensor(0.3), + ) + inps_2 = deepcopy(inps) + actual = fn_opt(*inps) + expected = fn(*inps_2) + self.assertTrue(same(actual, expected)) + self.assertEqual(cnt.frame_count, 1) + @make_test def test_functools_partial(a, b): return clip01(a + b) @@ -181,6 +240,22 @@ def test_itertools_chain_from_iterable(a, b): v = v + x return v + def test_itertools_reconstruct(self): + def fn(a): + it1 = itertools.repeat(1) + it2 = itertools.count(2) + for _ in range(3): + a += next(it1) + a += next(it2) + return it1, it2, a + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + i1, i2, a = fn(torch.ones(3, 3)) + it1, it2, b = opt_fn(torch.ones(3, 3)) + self.assertEqual(next(i1), next(it1)) + self.assertEqual(next(i2), next(it2)) + self.assertEqual(a, b) + @make_test def test_obj_eq(a, b): v = a + b @@ -240,6 +315,17 @@ def test_itertools_combinations(a, b): combs.append(torch.ones(size)) return combs + @unittest.skipIf( + sys.version_info < (3, 10), + "itertools.pairwise was added at Python 3.10", + ) + @make_test + def test_itertools_pairwise(a): + pairs = [] + for size in itertools.pairwise((1, 2, 3, 4)): + pairs.append(torch.ones(size)) + return pairs + @make_test def test_np_iinfo(a): max_dim = np.iinfo(np.int16).max @@ -449,8 +535,7 @@ def test_deque(a, b): empty = collections.deque() d.extend(empty) - # dynamo same() util doesn't support deque so just return a list - return list(d) + return d @make_test def test_slice1(a): @@ -614,6 +699,12 @@ def test_dict_setdefault3(x): else: return x - 1 + @make_test + def test_dict_update_kwargs(x): + d = {"a": 2} + d.update(b=4) + return x * d["a"] * d["b"] + @make_test def test_defaultdict_setdefault1(x): d = collections.defaultdict.fromkeys("a", "b") @@ -1683,6 +1774,22 @@ def test_dict_sorted(x): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) + def test_dict_hasattr(self): + def fn(x): + if hasattr(x, "to"): + return x.to("cpu") + if hasattr(x, "items"): + return torch.cos(x["a"]) + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = dict(a=torch.randn(3)) + self.assertEqual(fn(x), opt_fn(x)) + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] @@ -2072,6 +2179,13 @@ def test_in_not_in(x): assert 6 not in myotherlist return sum(mylist) + @make_test + def test_are_functorch_transforms_active(x): + if torch._C._are_functorch_transforms_active(): + return x + 1 + else: + return x - 1 + @make_test def test_partials_udf_kwarg(x): par_mul = functools.partial(udf_mul, y=torch.ones(10, 10)) @@ -3034,6 +3148,199 @@ def fn(a, ind, val): fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) ) + def test_map_return(self): + def fn(a, b): + return map(lambda x: x + 1, [a, b]) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) + self.assertIsInstance(m, map) + + @make_test + def test_map_max(a, b): + return max(map(lambda x: x.sum(), [a, b])) + + # max(map(...)) graph breaks + @unittest.expectedFailure + @make_test + def test_map_max_const(a): + return max(map(lambda x: x, [1, 2, 3])), a + 1 + + @make_test + def test_map_list(a, b): + return list(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_tuple(a, b): + return tuple(map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_iter(a, b): + it = iter(map(lambda x: x + 1, [a, b])) + return next(it) + + @make_test + def test_map_zip_dict(a): + d = dict( + zip( + map(lambda x: x + 1, [0, 1, 2]), + [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], + ) + ) + return list(d[3])[0], a + 1 # noqa: RUF015 + + @make_test + def test_map_dict_fromkeys(a): + return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 + + @make_test + def test_map_set(a): + return set(map(lambda x: x + 1, [0, 1])), a + 1 + + # test_map_sum defined earlier + + @make_test + def test_map_reduce(a, b): + return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_sorted(a): + return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 + + @make_test + def test_map_list_extend(a, b, c): + l = [a] + l.extend(map(lambda x: x + 1, [b, c])) + return l + + @make_test + def test_map_list_slice_assign(a, b, c, d, e): + l = [a, b, c] + l[1:2] = map(lambda x: x + 1, [d, e]) + return l + + @make_test + def test_map_deque_extendleft(a, b, c): + d = collections.deque([a]) + d.extendleft(map(lambda x: x + 1, [b, c])) + return d + + @make_test + def test_map_str_join(a): + return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 + + def test_map_with_graph_break(self): + def f(a): + a += 1 + + def g(x): + nonlocal a + a += 1 + return x + 1 + + m = map(g, [1, 2, 3, 4, 5]) + a += next(m) # won't graph break + torch._dynamo.graph_break() + a += next(m) # will graph break + return a + + cnts = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=cnts) + self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) + self.assertEqual(cnts.frame_count, 3) + + def test_map_reconstruct(self): + def fn(a): + return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, map) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + def test_zip_reconstruct(self): + def fn(a): + return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + m = opt_fn(torch.ones(3, 3))[0] + self.assertIsInstance(m, zip) + self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) + + @make_test + def test_map_partial_unpack(a, b): + y = 1 + + def f(x): + nonlocal y + y += 1 + return x + + l = list(zip([a, b], map(f, [1, 2, 3, 4]))) + return a + y + + @make_test + def test_map_call_function_ex(a, b): + def f(x, y): + return x + y + + return f(*map(lambda x: x + 1, [a, b])) + + @make_test + def test_map_unpack_twice(a, b): + m = map(lambda x: x + 1, [a, b]) + l1 = list(m) + l2 = list(m) + return l1, l2 + + @make_test + def test_enumerate(a, b): + return list(enumerate([a, b], start=1)), a + 1 + + @make_test + def test_map_enumerate(a, b): + return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 + + @make_test + def test_map_infinite(a, b): + return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) + + @make_test + def test_map_unpack_vars(a, b): + x, y = map(lambda x: x + 1, [a, b]) + return x + y + + def test_enumerate_custom(self): + class MyClass: + def __iter__(self): + self.a = 1 + return self + + def __next__(self): + if self.a > 3: + raise StopIteration + self.a += 1 + return self.a + + def fn(x): + for i, it in enumerate(MyClass()): + x += i + it + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) + + def test_enumerate_reconstruct(self): + def fn(a, b): + return enumerate([a, b], start=1) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + inps = (torch.randn(3, 3), torch.randn(3, 3)) + it1 = fn(*inps) + it2 = opt_fn(*inps) + self.assertIsInstance(it2, enumerate) + self.assertEqual(list(it1), list(it2)) + def udf_mul(x, y): return x * y @@ -3313,6 +3620,71 @@ def fn(x): ref = opt_fn(x) self.assertEqual(ref, res) + def test_frozenset_construction(self): + def fn(x): + s = frozenset({x}) + t = frozenset(s) + return len(t) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + res = fn(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + + def test_frozenset_reconstruction(self): + d = {} + f = frozenset() + d[f] = torch.randn(4) + + def fn(x): + k = frozenset() + torch._dynamo.graph_break() + return d[k] * x + + opt_fn = torch.compile(fn, backend="eager") + x = torch.randn(4) + res = fn(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + + def test_frozenset_illegal_call_method(self): + def fn_add(): + s = frozenset((1, 2, 3)) + s.add({2}) + return len(s) + + def fn_pop(): + s = frozenset((1, 2, 3)) + s.pop() + return len(s) + + def fn_update(): + s = frozenset((1, 2, 3)) + s.update({4, 5, 6}) + return len(s) + + def fn_remove(): + s = frozenset((1, 2, 3)) + s.remove(2) + return len(s) + + def fn_discard(): + s = frozenset((1, 2, 3)) + s.discard(2) + return len(s) + + def fn_clear(): + s = frozenset((1, 2, 3)) + s.clear() + return len(s) + + for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]: + torch._dynamo.reset() + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError): + opt_fn() + def test_is_tensor_tensor(self): def fn(x, y): if x is y: @@ -3524,10 +3896,36 @@ def fn(x, ys, zs): with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): nopython_fn(x, ys[:1], zs) + with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): + nopython_fn(x, ys, zs[:1]) + # Should cause fallback if allow graph break with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) + with self.assertRaisesRegex(ValueError, "zip()"): + opt_fn(x, ys, zs[:1]) + + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_cuda_current_device(self): + def fn(x): + y = torch.empty( + (2, 3), dtype=torch.float32, device=torch.cuda.current_device() + ) + y.copy_(x) + return torch.sin(y + y.device.index) + + counter = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(backend=counter, fullgraph=True)(fn) + + with torch.cuda.device(0): + x = torch.randn(2, 3) + self.assertEqual(opt_fn(x), fn(x)) + self.assertEqual(counter.frame_count, 1) + with torch.cuda.device(1): + self.assertEqual(opt_fn(x), fn(x)) + self.assertEqual(counter.frame_count, 2) + def test_fn_with_attr(self): def fn(x): if fn.pred: @@ -3586,6 +3984,22 @@ def foo_default_str(x): dynamo_class_name = dynamo_default_str[1].split(" object at")[0] self.assertEqual(eager_class_name, dynamo_class_name) + def test_pybind_object(self): + def fn(x, pybind_obj): + if pybind_obj.result: + return torch.cos(x) + return torch.sin(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0) + x = torch.randn(4) + self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) + + pybind_obj = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1) + x = torch.randn(4) + self.assertEqual(opt_fn(x, pybind_obj), fn(x, pybind_obj)) + instantiate_parametrized_tests(FunctionTests) diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index d7e85014ef3657..45f8cbe9690308 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import functools +import unittest import weakref import torch @@ -373,6 +374,14 @@ def test_weakref_alive_guard(self): del x self.assertFalse(guard(weakref_x())) + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_call_function_no_args_guard(self): + x = torch.cuda.current_device() + guard = guards.EQUALS_MATCH(x, [0]) + self.assertTrue(guard(0)) + self.assertFalse(guard(1)) + self.assertFalse(guard(2)) + def test_guard_manager_leaf_guard(self): guard_manager = RootGuardManager() guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"]) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 2ec9b7f5b681ad..bb109448d28482 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -23,6 +23,7 @@ normalize_gm, ) from torch._dynamo.utils import counters, ifdynstaticdefault +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.wrap import wrap from torch.testing._internal.common_utils import ( munge_exc, @@ -2502,6 +2503,139 @@ def fn(pred, pytree_in): ): torch.compile(fn, backend="eager")(pred, pytree_in) + def test_hints_wrapper(self): + def ref_fn(x, y): + x = x + y + x = torch.relu(x) + x = x + y + return torch.abs(x) + + def fn_with_hints(x, y): + x = x + y + + def inner_body_fn(x, y): + x = torch.relu(x) + x = x + y + return x + + def outer_body_fn(x, y): + x = hints_wrapper(inner_body_fn, (x, y), {}, hints={"inner_body": True}) + x = torch.abs(x) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"outer_body": True}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + eager_res = fn_with_hints(x, y) + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + ref_res = ref_fn(x, y) + self.assertEqual(eager_res, ref_res) + self.assertEqual(compiled_res, ref_res) + self.assertEqual(len(cnt.graphs), 1) + + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + graph = backend.graphs[0] + self.assertExpectedInline( + normalize_gm(graph.print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[2, 4]", L_y_: "f32[4]"): + l_x_ = L_x_ + l_y_ = L_y_ + + x: "f32[2, 4]" = l_x_ + l_y_; l_x_ = None + + hints_wrapper_body_1 = self.hints_wrapper_body_1 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_1, (x, l_y_), {}, hints = {'outer_body': True}); hints_wrapper_body_1 = x = l_y_ = None + res: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + return (res,) + + class hints_wrapper_body_1(torch.nn.Module): + def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): + hints_wrapper_body_0 = self.hints_wrapper_body_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_0, (x, l_y_), {}, hints = {'inner_body': True}); hints_wrapper_body_0 = x = l_y_ = None + x_1: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + + x_2: "f32[2, 4]" = torch.abs(x_1); x_1 = None + return (x_2,) + + class hints_wrapper_body_0(torch.nn.Module): + def forward(self, x: "f32[2, 4]", l_y_: "f32[4]"): + x_1: "f32[2, 4]" = torch.relu(x); x = None + + x_2: "f32[2, 4]" = x_1 + l_y_; x_1 = l_y_ = None + return (x_2,) +""", + ) + + def test_hints_wrapper_no_hints(self): + def fn_with_hints(x, y): + def outer_body_fn(x, y): + x = torch.add(x, y) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = "hints_wrapper - key hints not provided" + with self.assertRaisesRegex(RuntimeError, msg): + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + + def test_hints_wrapper_incorrect_type(self): + def fn_with_hints(x, y): + def outer_body_fn(x, y): + x = torch.add(x, y) + return x + + res = hints_wrapper(outer_body_fn, (x, y), {}, hints={"test": (True,)}) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = r"hints must be a dict containing int, float, bool or str value," + with self.assertRaisesRegex(RuntimeError, msg): + compiled_res = torch.compile(fn_with_hints, backend=cnt)(x, y) + + def test_hints_wrapper_pytree_inputs(self): + def fn_with_hints(x, y): + def outer_body_fn(x): + res = torch.add(x[0], x[1]["test"]) + return res + + res = hints_wrapper( + outer_body_fn, ((x, {"test": y}),), {}, hints={"test": True} + ) + return res + + backend = EagerAndRecordGraphs() + cnt = CompileCounterWithBackend(backend) + + x = torch.randn(2, 4) + y = torch.ones(4) + + msg = r"args must be a tuple of tensors, ints, floats, or bools," + with self.assertRaisesRegex(RuntimeError, msg): + fn_with_hints(x, y) + class HigherOrderOpVmapGuardTests(LoggingTestCase): @make_logging_test(recompiles=True) @@ -6307,6 +6441,9 @@ class _FallthroughTestOnly(torch._ops.HigherOrderOperator): def __init__(self): super().__init__("_fallthrough_test_only") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + test_op = _FallthroughTestOnly() default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS self.assertTrue( diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index a43951cc719f31..fe64ac745545f1 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -9,7 +9,8 @@ import torch._dynamo.test_case import torch._dynamo.testing import torch.distributed as dist -from torch._dynamo.testing import skipIfNotPy311 +from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311 +from torch._dynamo.trace_rules import _as_posix_path from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_utils import ( find_free_port, @@ -188,15 +189,6 @@ def throw(x): Traceback (most recent call last): File "test_logging.py", line N, in throw raise AssertionError -torch._inductor.exc.LoweringException: AssertionError: - target: aten.round.default - args[0]: TensorBox(StorageBox( - InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1])) - )) - -The above exception was the direct cause of the following exception: - -Traceback (most recent call last): torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: LoweringException: AssertionError: target: aten.round.default @@ -668,10 +660,18 @@ def fn(a): def test_logs_out(self): import tempfile - with tempfile.NamedTemporaryFile() as tmp: + with tempfile.NamedTemporaryFile(delete=False) as tmp: + file_path = _as_posix_path(tmp.name) + """ + NamedTemporaryFile will include a file open operation. + On Windowsm the file is opened by NamedTemporaryFile, the + following run_process_no_exception can't access a opened file. + And then, raise a PermissionError: [Errno 13] Permission denied: [file_path] + """ + tmp.close() env = dict(os.environ) env["TORCH_LOGS"] = "dynamo" - env["TORCH_LOGS_OUT"] = tmp.name + env["TORCH_LOGS_OUT"] = file_path stdout, stderr = self.run_process_no_exception( """\ import torch @@ -683,9 +683,18 @@ def fn(a): """, env=env, ) - with open(tmp.name) as fd: + with open( + file_path, encoding="utf-8" + ) as fd: # encoding file to UTF-8 for Windows. lines = fd.read() - self.assertEqual(lines, stderr.decode("utf-8")) + fd.close() + os.remove( + file_path + ) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False. + self.assertEqual( # process wrap difference: /r/n on Windows, /n on posix. + empty_line_normalizer(lines), + empty_line_normalizer(stderr.decode("utf-8")), + ) # single record tests @@ -723,6 +732,7 @@ def fn(a): "trace_shape_events", "cudagraph_static_inputs", "benchmarking", + "loop_ordering", } for name in torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dc9e99f770111f..e546463059f0c8 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -44,11 +44,12 @@ CompileCounter, CompileCounterWithBackend, expectedFailureDynamic, + requiresPy310, same, skipIfNotPy311, unsupported, ) -from torch._dynamo.utils import CompileProfiler, counters, ifdynstaticdefault +from torch._dynamo.utils import counters, ifdynstaticdefault from torch._inductor.utils import run_and_get_code from torch.ao.quantization import MinMaxObserver from torch.ao.quantization.fake_quantize import FakeQuantize @@ -112,16 +113,6 @@ def wrapper(*args, **kwargs): return wrapper -def cleanup_op(opname): - ns, name = opname.split("::") - if not hasattr(torch.ops, ns): - return - actual_ns = getattr(torch.ops, ns) - if not hasattr(actual_ns, name): - return - delattr(actual_ns, name) - - class MyPickledModule(torch.nn.Module): def __init__(self, z): super().__init__() @@ -317,6 +308,19 @@ def fn(x): "Graph break for an optree C/C++ function optree._C.PyCapsule.flatten. Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py", ) + def test_scalar_device_movement(self): + if not torch._dynamo.config.assume_static_by_default: + self.skipTest("Doesn't work with symints") + + def add_fn(a, b, out): + res = torch.add(a, b, out=out) + return res + + res = add_fn(2, 3, torch.tensor(0.0)) + add_fn = torch.compile(add_fn, backend="eager", fullgraph=True) + res_compiled = add_fn(2, 3, torch.tensor(0.0)) + self.assertEqual(res, res_compiled) + @skipIfNNModuleInlined("fails internal CI") @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") def test_cpp_extension_recommends_custom_ops(self): @@ -509,8 +513,7 @@ def fn(x): @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_pt2_compliant_ops_are_allowed(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::bar", "(Tensor x) -> Tensor", @@ -538,14 +541,10 @@ def g(x): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) _ = optimized_g(x) - finally: - cleanup_op("mylib::bar") - del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_non_pt2_compliant_ops_graph_break(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib) torch.library.impl( "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib @@ -574,14 +573,10 @@ def g(x): ): optimized_g = torch._dynamo.optimize(counts, nopython=True)(f) y = optimized_g(x) - finally: - cleanup_op("mylib::bar2") - del lib @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True) def test_pt2_compliant_overload(self): - lib = torch.library.Library("mylib", "FRAGMENT") - try: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::bar3.tensor", "(Tensor x) -> Tensor", @@ -630,70 +625,6 @@ def h(x): with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"): y = optimized_h(x) - finally: - cleanup_op("mylib::bar3") - del lib - - def test_auto_functionalize_can_with_default(self): - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - def foo_impl(a, b, c=None, d=None, e=-1): - a + b - return - - def f(a, mode): - return torch.ops.mylib.foo( - a, - 0, - ) - - a = torch.tensor([10, 10, 10], dtype=torch.int64) - - torch.compile(f)(a, 0) - - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_can_with_none_return(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - lib.define("foo(Tensor x, Tensor(a!) out) -> None") - - def foo_impl(x, out): - out.copy_(x) - - lib.impl("foo", foo_impl, "CompositeExplicitAutograd") - x = torch.randn(3) - out = torch.zeros(3) - - @torch.compile - def f(x, out): - torch.ops.mylib.foo(x, out) - - f(x, out) - - def test_auto_functionalize_self_as_mutate_arg(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - lib.define("foo(Tensor(a!) self) -> None") - - def foo_impl(self: torch.Tensor) -> None: - self.sin_() - - x = torch.randn(3) - lib.impl("foo", foo_impl, "CompositeExplicitAutograd") - - @torch.compile(backend="inductor", fullgraph=True) - def f(x): - torch.ops.mylib.foo(x) - - f(x) - def test_user_defined_setattr1(self): @torch.compile(backend="eager", fullgraph=True) def fn(obj): @@ -742,8 +673,7 @@ def fn(x, other_fn): self.assertEqual(cnt.frame_count, 2) def test_generate_trivial_abstract_impl(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( "mylib::foo", "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()", @@ -768,291 +698,6 @@ def f(x, y, z, w): output = torch.compile(f, backend="eager", fullgraph=True)(*args) self.assertEqual(output, None) - finally: - cleanup_op("mylib::foo") - del lib - - def test_can_auto_functionalize(self): - from torch._higher_order_ops.auto_functionalize import can_auto_functionalize - - expected_true = [ - "(Tensor(a!) x) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", - ] - expected_false = [ - "(Tensor x) -> ()", - "(Tensor(a) x) -> Tensor(a)", - "(Tensor(a!) x) -> Tensor(a!)", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", - "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", - "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", - ] - for schema in expected_true: - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define("mylib::a", schema, lib=lib) - self.assertTrue( - can_auto_functionalize(torch.ops.mylib.a.default), msg=schema - ) - self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) - finally: - cleanup_op("mylib::a") - del lib - for schema in expected_false: - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define("mylib::a", schema, lib=lib) - self.assertFalse( - can_auto_functionalize(torch.ops.mylib.a.default), msg=schema - ) - self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) - finally: - cleanup_op("mylib::a") - del lib - - def test_auto_functionalize(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - x.add_(y[0] + w) - z.add_(y[1] + n) - - def f(x, y, z, n): - torch.ops.mylib.foo(x, y, z, 2, n) - - x = torch.randn(3) - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - - # Check the graph under static shapes - if torch._dynamo.config.assume_static_by_default: - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None - return ()""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_with_returns(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - x.add_(y[0] + w) - z.add_(y[1] + n) - return y[0] + w, y[1] + n - - @torch.library.impl_abstract("mylib::foo", lib=lib) - def foo_abstract(x, y, z, w, n): - return y[0] + w, y[1] + n - - def f(x, y, z, n): - return torch.ops.mylib.foo(x, y, z, 2, n) - - x = torch.randn(3) - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( - *compiled_args - ) - - if torch._dynamo.config.assume_static_by_default: - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None - getitem_4: "f32[3][1]cpu" = foo_default[0] - getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - return (getitem_4, getitem_5)""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - eager_out = f(*eager_args) - self.assertEqual(compiled_args, eager_args) - self.assertEqual(compiled_out, eager_out) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_on_view(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!) x) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x): - x_np = x.detach().numpy() # view - np.sin(x_np, out=x_np) - return - - x = torch.randn(3) - expected = x.sin() - torch.ops.mylib.foo(x) - assert torch.allclose(x, expected) - - @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) - def f(x): - x = x.clone() - y = x[:] - torch.ops.mylib.foo(y) - return x - - y = f(x) - self.assertEqual(y, x.sin()) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_optional(self): - try: - lib = torch.library.Library("mylib", "FRAGMENT") - torch.library.define( - "mylib::foo", - "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(x, y, z, w, n): - if x is not None: - x.add_(y[0] + w) - if z is not None: - z.add_(y[1] + n) - - def f(x, y, z, n): - torch.ops.mylib.foo(x, y, z, 2, n) - - x = None - y = (torch.randn(3), torch.randn(3)) - z = torch.randn(3) - n = torch.randn(3) - orig_args = (x, y, z, n) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - if torch._dynamo.config.assume_static_by_default: - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() - self.assertExpectedInline( - post_grad_graphs, - """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None - return ()""", - ) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) - finally: - cleanup_op("mylib::foo") - del lib - - def test_auto_functionalize_tensorlist(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - torch.library.define( - "mylib::foo", - "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()", - tags=torch.Tag.pt2_compliant_tag, - lib=lib, - ) - - @torch.library.impl("mylib::foo", "cpu", lib=lib) - @torch._dynamo.disable - def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out): - for o in out: - o.copy_(all_gather_output) - - def f(all_gather_output, all_gather_input_split_sizes, dim, out): - torch.ops.mylib.foo( - all_gather_output, all_gather_input_split_sizes, dim, out - ) - - a = torch.ones(4) - b = [2, 3] - c = 0 - d = [torch.empty(4) for _ in range(2)] - orig_args = (a, b, c, d) - - compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) - - eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) - f(*eager_args) - self.assertEqual(compiled_args, eager_args) def test_shape_int_inplace_binops(self): def fn(x): @@ -1616,6 +1261,22 @@ def _fn(a, b, func=func): expected_ops_dynamic=ifdynstaticdefault(1, 5), ) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_arange_length_with_float32_dtype(self): + @torch.compile(fullgraph=True) + def f(x): + y = x.item() + torch._check_is_size(y) + r = torch.arange(y, dtype=torch.float32) + + if r.size(0) == y: + return r + 1 + + return r + + x = torch.tensor([300]) + r = f(x) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_torch_check(self): cnts = torch._dynamo.testing.CompileCounter() @@ -2353,6 +2014,17 @@ def fn(cfg, x, y): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) + def test_data_access_in_inference_mode(self): + @torch.compile(fullgraph=True) + def f(x): + y = x.data + return y + + with torch.inference_mode(): + x = torch.randn(3) + y = f(x) + self.assertEqual(y, x) + def test_dataclass_fields(self): @dataclasses.dataclass class MyDataClass: @@ -3708,6 +3380,21 @@ def fn4(x) -> None: self.assertTrue(same(obj41.y, obj42.y)) self.assertEqual(cnts.frame_count, 1) + def test_thread_local_setattr(self): + from threading import local + + loc = local() + + @torch.compile(fullgraph=True) + def fn(x, l): + l.x = x + return x + 1 + + x = torch.ones(2, 2) + fn(x, loc) + + self.assertTrue(loc.x is x) + def test_user_defined_class_name(self): class MyClassFoo: pass @@ -8228,49 +7915,6 @@ def fn(a, b): fn(torch.rand(2, 3), torch.rand(2, 3)) fn(torch.rand(2, 3), (1, 2, 3)) - @expectedFailureDynamic - @torch._dynamo.config.patch(automatic_dynamic_shapes=False) - def test_compile_profiler(self): - class Model(torch.nn.Module): - def forward(self, input): - return input + input - - model = Model() - prof = CompileProfiler() - compiled = torch.compile(model, backend=prof) - base_checker = ( - lambda: FileCheck() - .check("Torchdynamo Profiler Report") - .check("Graph Breaks") - .check("No graph breaks detected.") - .check("Recompilation") - ) - input = torch.rand((2, 3, 4)) - _ = compiled(input) - base_checker().check("No recompilation detected.").run(prof.report()) - - new_shape_input = torch.rand((3, 3, 4)) - _ = compiled(new_shape_input) - - # Not an exhaustive test of dynamic shapes behavior, but some sanity - if torch._dynamo.config.assume_static_by_default: - base_checker().check("Recompile Reasons").check("'forward'").check( - "cache_size_limit to 1" - ).run(prof.report()) - else: - base_checker().check("No recompilation detected.").run(prof.report()) - - new_shape_input = torch.rand((4, 3, 4)) - _ = compiled(new_shape_input) - - base_checker().check("Recompile Reasons").check("'forward'").check( - "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" - ).check( - "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" - ).run( - prof.report() - ) - def test_guards_strip_function_call(self): from torch._dynamo.guards import strip_function_call @@ -8945,25 +8589,6 @@ def f(lengths, values): @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) - def test_unbacked_auto_functionalize_op(self): - @torch.library.custom_op( - "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] - ) - def mk_image(decoder: Tensor) -> Tensor: - return torch.randn(2, 3, 4, 5) - - @torch.library.register_fake("mylib::mk_image") - def _(decoder: Tensor) -> Tensor: - image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] - return torch.empty(image_size) - - @torch.compile(fullgraph=True) - def f(x): - return torch.ops.mylib.mk_image.default(x) - - x = torch.zeros(100, dtype=torch.int64) - f(x) - def test_out_variant_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: lib.define( @@ -10126,6 +9751,113 @@ def forward(self, x): } self.assertEqual(expected_fqn, gm.meta["dynamo_flat_name_to_original_fqn"]) + def test_proxy_frozen_dataclass(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + def test_reconstruct_frozen_dataclass(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + + def fn(x, y): + dc = TestDataClass(x, y) + torch._dynamo.graph_break() + return dc.x + dc.y + + fn_opt = torch.compile()(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + def test_frozen_dataclass_default_value(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(default=5) + a: int = 6 + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.z + dc.a + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + def test_frozen_dataclass_default_factory(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(default_factory=list) + a: int = dataclasses.field(default_factory=lambda: [5]) + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.a[0] + + def fn(x, y): + dc = TestDataClass(x, y) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + + @requiresPy310 + def test_frozen_dataclass_kw_only(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: torch.Tensor + y: torch.Tensor + z: int = dataclasses.field(kw_only=True) + a: int = dataclasses.field(kw_only=True) + + @allow_in_graph + def inner_fn(dc): + return dc.x + dc.y + dc.a + dc.z + + def fn(x, y): + dc = TestDataClass(x, y, z=5, a=2) + return inner_fn(dc) + + fn_opt = torch.compile(fullgraph=True)(fn) + inps = (torch.ones(2, 2), torch.ones(2, 2)) + actual = fn_opt(*inps) + expected = fn(*inps) + + self.assertEqual(actual, expected) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -10713,6 +10445,21 @@ def fn(x): c2 = _debug_get_cache_entry_list(fn.__code__) self.assertEqual(len(c2), 0) + def test_guard_size_oblivious_simplification(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + u0, u1 = x.tolist() + torch._check_is_size(u0) + torch._check_is_size(u1) + torch._check((2 * u0) % (u0 + u1) == 0) + torch._check((2 * u0) // (u0 + u1) != 0) + if guard_size_oblivious((2 * u0) // (u0 + u1) == 0): + return torch.tensor(True) + else: + return torch.tensor(False) + + fn(torch.tensor([3, 3])) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_guard_size_oblivious(self): # This code, in fact, does NOT work in eager @@ -11099,6 +10846,23 @@ def gn(x): self.assertEqual(bound0, bound1) + def test_inspect_signature_parameters(self): + import inspect + + def fn(x, gn): + d = inspect.signature(gn).parameters + if d["a"].default is inspect.Parameter.empty: + return torch.sin(x + 1) + else: + return torch.cos(x + 1) + + def gn(a: torch.Tensor, b: int) -> torch.Tensor: + return a + b + + x = torch.randn(2, 3) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + self.assertEqual(fn(x, gn), opt_fn(x, gn)) + def test_grad_none(self): def fn(x, y): x.grad = torch.abs(y) @@ -11503,6 +11267,72 @@ def fn(x): fn(torch.randn(4)) + def test_tuple_class(self): + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + + r = opt_fn((d1, d2)) + self.assertEqual(r.__class__, tuple) + r1, r2 = r + self.assertEqual(r1, torch.ones(2, 2)) + self.assertEqual(r2, torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + + def test_list_class(self): + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + + r = opt_fn([d1, d2]) + self.assertEqual(r.__class__, list) + self.assertEqual(len(r), 2) + self.assertEqual(r[0], torch.ones(2, 2)) + self.assertEqual(r[1], torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + + def test_namedtuple_class(self): + import collections + + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x): + updated_x = [] + for v in x: + updated_x.append(v + 1) + return x.__class__(*updated_x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + + d1 = torch.zeros(2, 2) + d2 = torch.ones(2, 2) + point = collections.namedtuple("Point", ["x", "y"]) + p = point(d1, d2) + + r = opt_fn(p) + self.assertEqual(r.__class__, point) + self.assertEqual(r.x, torch.ones(2, 2)) + self.assertEqual(r.y, torch.ones(2, 2) + 1) + self.assertEqual(cnts.frame_count, 1) + class TestTracer(JitTestCase): def test_jit_save(self): @@ -11533,6 +11363,59 @@ def forward(self, x): opt_fn() +class TestCustomFunction(torch.testing._internal.common_utils.TestCase): + def test_autograd_function_with_matmul_folding_at_output(self): + """ + When tensor folding occurs during matmul operation returned tensor is a view. + This can cause issues when matmul is used inside a custom function + and such view is then returned as output. Then it cannot be modified inplace + and causes errors. + It can be especially problematic when after such function inplace allreduce + is performed. This test recreates this behaviour. + Issue is resolved when unsafe_view is returned from matmul instead. + """ + + class CustomFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp1, inp2): + ctx.save_for_backward(inp2) + ctx.output_shape = inp1.size() + return torch.matmul(inp1, inp2) + + @staticmethod + def backward(ctx, grad_output): + output_shape = ctx.output_shape + (inp2,) = ctx.saved_tensors + return ( + torch.mm(grad_output.squeeze(), inp2.t()).view(output_shape), + None, + ) + + def outer_function(inp1, inp2): + res = CustomFunction.apply(inp1, inp2) + res.add_(1.0) + return res.sum() + + def usual_function(inp1, inp2) -> torch.Tensor: + res = torch.matmul(inp1, inp2) + res.add_(1.0) + return res.sum() + + inp1_custom = torch.randn(4, 1, 2, requires_grad=True) + inp1_usual = inp1_custom.detach().clone().requires_grad_(True) + + inp2 = torch.randn(2, 4) + c_custom_func = torch.compile(outer_function) + c_usual_func = torch.compile(usual_function) + + result_custom = c_custom_func(inp1_custom, inp2) + result_custom.backward() + result_usual = c_usual_func(inp1_usual, inp2) + result_usual.backward() + + torch.allclose(inp1_custom.grad, inp1_usual.grad) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index b6d6fd050670c7..e9eb78254ce067 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -6,6 +6,8 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import TEST_HPU, TestCase try: @@ -45,6 +47,21 @@ def fn(a, tmp): res = opt_fn(x, tmp) self.assertTrue(same(ref, res)) + @maybe_skip + def test_pretrained_non_const_attr(self): + def fn(a, tmp): + if tmp.pruned_heads: + return a + 1 + else: + return a - 1 + + x = torch.randn(2) + tmp = PretrainedConfig() + ref = fn(x, tmp) + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + res = opt_fn(x, tmp) + self.assertTrue(same(ref, res)) + class TestModelOutput(torch._dynamo.test_case.TestCase): @maybe_skip @@ -238,9 +255,34 @@ def fn(inp): self.assertEqual(fn(inp).attentions, opt_fn(inp).attentions) @maybe_skip - def test_HF_bert_model_output(self): - device = "cuda" if torch.cuda.is_available() else "cpu" + def test_none(self): + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=None, logits=x)[0] + + model = Model() + opt_model = torch.compile(model, backend="eager", fullgraph=True) + x = torch.randn(1, 1, 1, 1) + + self.assertTrue(same(model(x), opt_model(x))) + + @maybe_skip + def test_reconstruction(self): + class Model(torch.nn.Module): + def forward(self, x): + x = x + 1 + return CausalLMOutputWithPast(loss=x, logits=None) + + model = Model() + x = torch.randn(1, 1, 1, 1) + eo = torch._dynamo.export(Model(), aten_graph=True)(x) + self.assertTrue(same(model(x), eo.graph_module(x))) + +class TestModelOutputBert(TestCase): + @maybe_skip + def test_HF_bert_model_output(self, device): class BertPooler(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -316,31 +358,12 @@ def forward( torch.allclose(orig_result.pooler_output, compiled_result.pooler_output) ) - @maybe_skip - def test_none(self): - class Model(torch.nn.Module): - def forward(self, x): - x = x + 1 - return CausalLMOutputWithPast(loss=None, logits=x)[0] - - model = Model() - opt_model = torch.compile(model, backend="eager", fullgraph=True) - x = torch.randn(1, 1, 1, 1) - - self.assertTrue(same(model(x), opt_model(x))) - - @maybe_skip - def test_reconstruction(self): - class Model(torch.nn.Module): - def forward(self, x): - x = x + 1 - return CausalLMOutputWithPast(loss=x, logits=None) - model = Model() - x = torch.randn(1, 1, 1, 1) - eo = torch._dynamo.export(Model(), aten_graph=True)(x) - self.assertTrue(same(model(x), eo.graph_module(x))) +devices = ["cpu", "cuda"] +if TEST_HPU: + devices.append("hpu") +instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index ee8a9b579ff124..4d1f2bbea389ef 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -from unittest.mock import patch import torch import torch._dynamo.test_case @@ -14,6 +13,17 @@ from torch.utils._python_dispatch import TorchDispatchMode +class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + class TorchDispatchModeTests(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -57,9 +67,11 @@ def tearDownClass(cls): def setUp(self): torch.set_default_device(None) + torch._dynamo.reset() def tearDown(self): torch.set_default_device(None) + torch._dynamo.reset() def _run_torch_function_mode_guard_test(self): class TestMode1(BaseTorchFunctionMode): @@ -94,70 +106,6 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 4) - def _run_ignored_mode_types_test(self): - class IgnoredMode(BaseTorchFunctionMode): - pass - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt.__call__, fullgraph=True) - def fn(x): - return x + 1 - - inp = torch.ones(2, 2) - - with patch( - "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode} - ): - # initial compile - fn(inp) - - # no recompile, mode ignored - # note: the ref stack is length 0, and the stack we are checking against has length 2 - # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack - with IgnoredMode(), IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 1) - - # recompile due to new mode on the stack - with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 2) - - # recompile - # tests both ref stack len > runtime stack len for the above guard check - # and ref stack len < runtime stack len for the initial zero mode case - with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # no recompile - with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 3) - - # This is tricky, basically the ignored modes are baked into the guard - # IgnoredMode will be ignored forever by that guard. - # This is okay since we don't expect to be modifying IGNORED_MODES - # in the middle of execution except for the purposes of testing. - torch._dynamo.reset() - - with IgnoredMode(): - fn(inp) - - self.assertEqual(cnt.frame_count, 4) - - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) - def test_torch_function_mode_guards_ignored_types_py(self): - self._run_ignored_mode_types_test() - - def test_torch_function_mode_guards_ignored_types_cpp(self): - self._run_ignored_mode_types_test() - @torch._dynamo.config.patch("enable_cpp_guard_manager", False) def test_torch_function_mode_guards_py(self): self._run_torch_function_mode_guard_test() @@ -324,6 +272,218 @@ def fn(x): fn(inp) self.assertEqual(cnt.frame_count, 2) + def test_nested_torch_function_mode(self): + mode_1_called = False + mode_2_called = False + + def reset_state(): + nonlocal mode_1_called + nonlocal mode_2_called + mode_1_called = False + mode_2_called = False + + ones = torch.ones(2, 2) + zeros = torch.zeros(2, 2) + + class TestMode1(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_1_called + + mode_1_called = True + + if func == torch.add: + return zeros + + return super().__torch_function__(func, types, args, kwargs) + + class TestMode2(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + nonlocal mode_2_called + + mode_2_called = True + + if func == torch.mul: + return ones + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + def fn_2(x): + return torch.mul(x, 3) + torch.add(x, 3) + + inp = torch.ones(2, 2) + 1 + + for fn_i in [fn, fn_2]: + fn_opt = torch.compile(fn_i, fullgraph=True) + with TestMode1(), TestMode2(): + expected = fn_i(inp), mode_1_called, mode_2_called + reset_state() + actual = fn_opt(inp), mode_1_called, mode_2_called + reset_state() + + self.assertEqual(expected, actual) + + def test_torch_function_mode_disable(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + class TestMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + + if func == torch.add: + return torch.zeros(2, 2) + + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + with torch._C.DisableTorchFunctionSubclass(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + with torch._C.DisableTorchFunction(): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_highest_priority(self): + class TestSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args, kwargs=None): + if not kwargs: + kwargs = {} + if func == torch.add: + return torch.ones(2, 2) + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return torch.add(x, 3) + + inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass) + + fn_opt = torch.compile(fn, fullgraph=True) + with TestMode(), torch._dynamo.config.patch( + "traceable_tensor_subclasses", {TestSubclass} + ): + expected = fn(inp) + actual = fn_opt(inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_enter_exit(self): + def fn(x, y): + with TestMode(): + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn, fullgraph=True) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_graph_break(self): + def fn(x, y): + with TestMode(): + torch._dynamo.graph_break() + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_and_pop_graph_break(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + + def test_torch_function_mode_restore_on_exc(self): + @torch._dynamo.disable() + def err(): + raise RuntimeError("test") + + @torch.compile() + def fn(x): + with TestMode(): + x += 1 + err() + x += 2 + return x + + try: + fn(torch.ones(2, 2)) + except RuntimeError: + pass + self.assertEqual(_len_torch_function_stack(), 0) + + def test_torch_function_mode_and_pop_graph_break_mutation(self): + def fn(x, y): + with TestMode(): + z = _pop_torch_function_stack() + z.y = 5 + torch._dynamo.graph_break() + _push_on_torch_function_stack(z) + o = torch.add(x, 3) + o = torch.mul(o, z.y) + + return torch.add(o, y) + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2) + 2) + fn_opt = torch.compile(fn) + + expected = fn(*inp) + actual = fn_opt(*inp) + + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 6e63389b0724b5..329b04fd7d810c 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1994,9 +1994,8 @@ def fn(x, mod): mod = Mod() opt_mod(torch.randn(5, 5), mod) - # fn compiles twice, and forward twice - # (since forward is inlined when fn is compiled) - self.assertEqual(cnts.frame_count, 4) + # fn compiles twice + self.assertEqual(cnts.frame_count, 2) @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) def test_inline_inbuilt_nn_modules(self): @@ -3013,6 +3012,40 @@ def fn(x): with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): helper() + def test_user_defined_nn_module_dynamic(self): + class Conv2d(torch.nn.Conv2d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = torch.nn.functional.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return x + + cnts = torch._dynamo.testing.CompileCounter() + mod1 = Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1)) + mod2 = Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2)) + mod3 = Conv2d(64, 64, kernel_size=(2, 2), stride=(3, 3)) + + opt_mod1 = torch.compile(mod1, backend=cnts, fullgraph=True) + opt_mod2 = torch.compile(mod2, backend=cnts, fullgraph=True) + opt_mod3 = torch.compile(mod3, backend=cnts, fullgraph=True) + + x = torch.randn(1, 64, 64, 64) + opt_mod1(x) + opt_mod2(x) + opt_mod3(x) + + # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints. + self.assertEqual(cnts.frame_count, 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 0824a4164b3f81..58395bb7c914fc 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -1,10 +1,12 @@ # Owner(s): ["module: dynamo"] +import logging from unittest.mock import patch import torch import torch._dynamo.test_case import torch._dynamo.testing import torch._dynamo.utils +import torch._logging from torch._dynamo.utils import dynamo_timed from torch.testing._internal.common_utils import TemporaryFileName @@ -163,20 +165,41 @@ def fn(x, y, z): ) def test_profiler_dynamo_compiled_region(self): - def fn(x, y, z): - return x @ y + z + torch._logging.set_logs(dynamo=logging.INFO) - opt_fn = torch._dynamo.optimize("eager")(fn) + def fn(x, y): + r = y.sum(dim=1) + print(r.shape) + return x * r - inputs = [torch.rand(4, 4) for _ in range(3)] + fn_c = torch.compile(fn) - for _ in range(2): - opt_fn(*inputs) + with torch.profiler.profile(record_shapes=True) as prof: + fn_c( + torch.randn(10), + torch.randn(10, 10), + ) - with torch.profiler.profile() as prof: - opt_fn(*inputs) + fn_c( + torch.randn(10), + torch.randn(10, 15), + ) - self.assertTrue(any(e.name == "Torch-Compiled Region" for e in prof.events())) + for e in prof.events(): + if e.name == "Torch-Compiled Region": + print(e.kwinputs) + self.assertTrue( + any( + e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "0/0_1" + for e in prof.events() + ) + ) + self.assertTrue( + any( + e.name == "Torch-Compiled Region" and e.kwinputs["context"] == "1/0" + for e in prof.events() + ) + ) if __name__ == "__main__": diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index bab95c799a1263..f0cba5132cf3a5 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -315,6 +315,25 @@ def forward(self, x): model(x) self.assertEqual(counter.frame_count, 2) + @patch.object(torch._dynamo.config, "cache_size_limit", 2) + def test_no_recursive_compile_after_cache_limit_hit(self): + def f(x, n): + x = x + n + return g(x, n) + + def g(x, n): + x = x + n + return h(x, n) + + def h(x, n): + return x + n + + counter = torch._dynamo.testing.CompileCounter() + opt_f = torch.compile(f, backend=counter, dynamic=False) + for i in range(10): + opt_f(torch.ones(3), i) + self.assertEqual(counter.frame_count, 2) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 0363ac4b9e2ed9..338da69b20107d 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -20,7 +20,7 @@ from abc import ABC from collections import namedtuple from copy import deepcopy -from enum import Enum +from enum import Enum, IntEnum from functools import wraps from typing import Any, Dict, Iterator, List, Tuple from unittest import mock @@ -457,7 +457,7 @@ def forward( if past_key_value is not None: assert ( len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -4546,6 +4546,84 @@ def f(*args): f(*args) self.assertEqual(num_compiles, 1) + def test_issue134451(self): + class BoundingBox2DIndex(IntEnum): + _X = 0 + _Y = 1 + _HEADING = 2 + _LENGTH = 3 + _WIDTH = 4 + + @classmethod + def size(cls): + return 5 + + @classmethod + @property + def X(cls): + return cls._X + + @classmethod + @property + def Y(cls): + return cls._Y + + @classmethod + @property + def HEADING(cls): + return cls._HEADING + + @classmethod + @property + def LENGTH(cls): + return cls._LENGTH + + @classmethod + @property + def WIDTH(cls): + return cls._WIDTH + + @classmethod + @property + def POINT(cls): + # assumes X, Y have subsequent indices + return slice(cls._X, cls._Y + 1) + + @classmethod + @property + def STATE_SE2(cls): + # assumes X, Y, HEADING have subsequent indices + return slice(cls._X, cls._HEADING + 1) + + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self._mlp_states = nn.Sequential( + nn.Linear(10, 20), + nn.ReLU(), + nn.Linear(20, BoundingBox2DIndex.size()), + ) + + def forward(self, x): + agent_states = self._mlp_states(x) + agent_states[..., BoundingBox2DIndex.POINT] = ( + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 + ) + agent_states[..., BoundingBox2DIndex.HEADING] = ( + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi + ) + return agent_states + + model = SimpleModel().eval() + input_tensor = torch.randn(1, 10, dtype=torch.float32) + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) + actual = opt(input_tensor) + try: + expected = model(input_tensor) + except Exception as e: + raise unittest.SkipTest("eager failed, requires Python>=3.12") from e + self.assertEqual(actual, expected) + def test_invalid_seq_unpack(self): def myfn(arg): (a, b) = arg @@ -4922,6 +5000,55 @@ def fn(obj): compiled_str = str(e) self.assertEqual(orig_str, compiled_str) + def test_super_staticmethod(self): + class Parent: + @staticmethod + def greet(): + return 5 + + class Child(Parent): + @staticmethod + def greet(x): + return x * super(Child, Child).greet() + + child = Child() + + def fn(x): + return child.greet(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.ones(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_super_diamond(self): + class A: + def __init__(self): + super().__init__() + self.a = 5 + + class Nothing: + pass + + class B(Nothing, A): + def __init__(self): + super().__init__() + self.b = 10 + + def run(self, x): + return self.a * self.b * x + + def fn(x): + b = B() + return b.run(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_vc_bumped_in_inference_graph(self): @torch.compile def f(x): @@ -5427,15 +5554,17 @@ def gen_inps(len_x, len_y): return x, y def g(x, y): - return tuple(map(f, x, y)) + return map(f, x, y) opt_g = torch.compile(g, fullgraph=True, backend="eager") inps = gen_inps(3, 3) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) inps = gen_inps(3, 5) - self.assertEqual(g(*inps), opt_g(*inps)) + self.assertEqual(type(g(*inps)), type(opt_g(*inps))) + self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) def test_staticmethod_allow_in_graph(self): class MyClass: @@ -5506,7 +5635,7 @@ def f(x, l): z0 = x.sin() z1 = x.sin() y = x + 1 - torch.ops.fsdp.set_.default(x, y) + torch.ops.fsdp.copy_.default(x, y) # z3 and z3 can be CSEd with each other, # but *not* with z0/z1 (they cross a mutation boundary) z2 = x.sin() @@ -5538,7 +5667,7 @@ def f(x, l): z = x.sin() y = x + 1 # graph input has its storage mutated - torch.ops.fsdp.set_.default(x, y) + torch.ops.fsdp.copy_.default(x, y) z2 = x.sin() return z2, l**2 @@ -5703,6 +5832,195 @@ def fn(x): fn(torch.randn(4)) + @requires_cuda + # This test will fail as flip in combination with particular input lenghts + # produces weird results. + # This is under investigations in + # https://github.com/pytorch/pytorch/issues/131805 + @unittest.skip("Skip this flip test for the moment. It is under investigation") + def test_flip_bad_accuracy(self): + import torch + import torch._dynamo.config + import torch._functorch.config + import torch._inductor.config + import torch._inductor.inductor_prims + import torch.fx.experimental._config + + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, arg0_1): + rev = torch.ops.prims.rev.default(arg0_1, [0]) + arg0_1 = None + slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2) + slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2) + add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2) + slice_1 = slice_2 = None + slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2) + slice_4 = torch.ops.aten.slice.Tensor( + add_1, 0, 1, 9223372036854775807, 2 + ) + add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4) + slice_3 = slice_4 = None + slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2) + slice_6 = torch.ops.aten.slice.Tensor( + add_2, 0, 1, 9223372036854775807, 2 + ) + add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6) + slice_5 = slice_6 = None + slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1) + add_2 = None + unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1) + slice_9 = None + unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1) + add_3 = None + cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) + unsqueeze = unsqueeze_1 = None + view = torch.ops.aten.view.default(cat, [2]) + cat = None + slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1) + slice_11 = torch.ops.aten.slice.Tensor( + add_1, 0, 2, 9223372036854775807, 2 + ) + add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11) + slice_10 = slice_11 = None + slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1) + add_1 = None + cat_1 = torch.ops.aten.cat.default([slice_12, add_5]) + slice_12 = add_5 = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1) + cat_1 = None + unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1) + view = None + cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1) + unsqueeze_2 = unsqueeze_3 = None + view_1 = torch.ops.aten.view.default(cat_2, [4]) + cat_2 = None + slice_13 = torch.ops.aten.slice.Tensor( + rev, 0, 2, 9223372036854775807, 2 + ) + add_6 = torch.ops.aten.add.Tensor(view_1, slice_13) + slice_13 = None + slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1) + rev = None + cat_3 = torch.ops.aten.cat.default([slice_14, add_6]) + slice_14 = add_6 = None + constant_pad_nd = torch.ops.aten.constant_pad_nd.default( + view_1, [0, 1], 0.0 + ) + view_1 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1) + cat_3 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1) + constant_pad_nd = None + cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1) + unsqueeze_4 = unsqueeze_5 = None + view_2 = torch.ops.aten.view.default(cat_4, [10]) + cat_4 = None + slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9) + view_2 = None + rev_1 = torch.ops.prims.rev.default(slice_15, [0]) + slice_15 = None + return (rev_1,) + + mod = Repro() + x = torch.arange(9, device=torch.device("cuda")) + + @torch.compile + def f(x): + return mod(x) + + out = f(x) + self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0]) + + # https://github.com/pytorch/pytorch/issues/88813 + def test_return_value_duplication_tensor(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return val * 2, val * 2 + + x = torch.randn(2, requires_grad=True) + + expect = fn(x) + self.assertNotEqual( + expect[0].untyped_storage().data_ptr(), + expect[1].untyped_storage().data_ptr(), + ) + + actual = torch.compile(fn, backend="aot_eager")(x) + self.assertNotEqual( + actual[0].untyped_storage().data_ptr(), + actual[1].untyped_storage().data_ptr(), + ) + + # https://github.com/pytorch/pytorch/issues/114344 + def test_return_value_duplication_mixed_grad(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + out0 = val + 1 + out1 = val + 1 + return out0, out1 + + x = torch.randn(2, requires_grad=True) + + with torch.enable_grad(): + expect = fn(x) + actual = torch.compile(fn, backend="aot_eager")(x) + + self.assertEqual(expect[0].requires_grad, actual[0].requires_grad) + self.assertEqual(expect[1].requires_grad, actual[1].requires_grad) + + # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371 + def test_return_value_duplication_scalar(self) -> None: + def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, y = val * 2, val * 2 + return x[0], y[0] + + x = torch.randn(2, requires_grad=True) + + expect = fn(x) + self.assertNotEqual( + expect[0].untyped_storage().data_ptr(), + expect[1].untyped_storage().data_ptr(), + ) + + actual = torch.compile(fn, backend="aot_eager")(x) + self.assertNotEqual( + actual[0].untyped_storage().data_ptr(), + actual[1].untyped_storage().data_ptr(), + ) + + def test_torch_compile_in_compile_frame(self): + def gn(x, c=None): + if c is None: + c = 2 + return c * x + + def outer_func(x): + return torch.compile(gn, backend="eager")(x) + + compile_outer = torch.compile(outer_func, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = outer_func(x) + res = compile_outer(x) + self.assertEqual(ref, res) + + # https://github.com/pytorch/pytorch/issues/119162 + def test_inductor_rng_default_dtype(self) -> None: + @torch.compile + def fn(): + tmp = torch.randn(4, 4, dtype=torch.bfloat16) + return tmp + + try: + old = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + out = fn() + finally: + torch.set_default_dtype(old) + # output dtype should be float32 + self.assertEqual(out.dtype, torch.bfloat16) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_sources.py b/test/dynamo/test_sources.py index 48646ac44c5ef3..0f2f7ded33fea5 100644 --- a/test/dynamo/test_sources.py +++ b/test/dynamo/test_sources.py @@ -72,7 +72,6 @@ def forward(self): lambda x, _: CausalLMOutputWithPast(), ) - # breakpoint() torch.export.export(Model(), ()) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 91089220a82a52..cdb7bba77fe91c 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -23,6 +23,8 @@ from torch.testing._internal.inductor_utils import HAS_CUDA +HAS_TLPARSE = shutil.which("tlparse") is not None +requires_tlparse = unittest.skipUnless(HAS_TLPARSE, "requires tlparse") requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") requires_distributed = functools.partial( unittest.skipIf, not dist.is_available(), "requires distributed" @@ -213,6 +215,7 @@ def test_cudagraphs(self): self.assertParses() + @requires_tlparse def test_recompiles(self): def fn(x, y): return torch.add(x, y) @@ -257,6 +260,7 @@ def fn(x, y): self.assertParses() + @requires_tlparse def test_example_fn(self): fn_opt = torch._dynamo.optimize("inductor")(example_fn) fn_opt(torch.ones(1000, 1000)) @@ -280,6 +284,7 @@ def test_example_fn(self): self.assertParses() + @requires_tlparse def test_dynamo_error(self): try: fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn) @@ -293,12 +298,14 @@ def test_dynamo_error(self): {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) self.assertParses() + @requires_tlparse def test_inductor_error(self): import torch._inductor.lowering @@ -331,6 +338,7 @@ def throw(x): {"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 ) @@ -463,6 +471,7 @@ def forward(self, x): self.assertParses() + @requires_tlparse def test_graph_breaks(self): @torch._dynamo.optimize("inductor") def fn(x): @@ -496,6 +505,7 @@ def fn(x): # TODO: bring in the trace_source tests once we start emitting bytecode + @requires_tlparse def test_graph_sizes_dynamic(self): def fn(a, b): return a @ b @@ -535,6 +545,7 @@ def fn(a, b): self.assertParses() + @requires_tlparse def test_guards_recompiles(self): def fn(x, ys, zs): return inner(x, ys, zs) @@ -606,6 +617,7 @@ def forward(self, x, y): """, # noqa: B950 ) + @requires_tlparse @torch._inductor.config.patch("fx_graph_cache", True) def test_codecache(self): def fn(a): @@ -648,6 +660,7 @@ def fn(a): ) self.assertParses() + @requires_tlparse @torch._inductor.config.patch("fx_graph_cache", True) @show_chrome_events def test_chromium_event(self): diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 7137fd419b32be..5379405bfbe58e 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -30,6 +30,7 @@ ) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor +from torch.utils._python_dispatch import return_and_correct_aliasing def traceable_subclass(c): @@ -1427,6 +1428,99 @@ def __metadata_guard__(self, x, y): lambda: torch.compile(lambda x: x * x)(x), ) + def test_subclass_constructor_proxying(self): + import dataclasses + from collections import namedtuple + from typing import Any + + @dataclasses.dataclass(frozen=True) + class SubclassTensorArgs: + original_shape: torch.Size + device: torch.device + inner_meta: Any + + SubclassTensorArgs2 = namedtuple( + "SubclassTensorArgs2", + [ + "original_shape", + "device", + "inner_meta", + ], + ) + + class SubclassTensor(torch.Tensor): + @staticmethod + def __new__(cls, a, meta): + shape = a.shape + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = a.device + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, a, meta): + self.a = a + self.meta = meta + + def __repr__(self): + a_repr = repr(self.a) + return f"SubclassTensor({a_repr})" + + def __tensor_flatten__(self): + return ["a"], self.meta + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, _, __): + a = inner_tensors["a"] + return SubclassTensor(a, meta) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, args + ) + kwargs_a = pytree.tree_map( + lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs + ) + out_a = func(*args_a, **kwargs_a) + out = pytree.tree_map( + lambda x: SubclassTensor( + x, SubclassTensorArgs2(x.shape, x.device, None) + ) + if isinstance(x, torch.Tensor) + else x, + out_a, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + @torch.compile(fullgraph=True) + def f1(x): + meta = SubclassTensorArgs( + x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) + ) + out = SubclassTensor(x, meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + + @torch.compile(fullgraph=True) + def f1(x): + meta = SubclassTensorArgs2( + x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) + ) + out = SubclassTensor(x, meta) + return out * out + + x = torch.randn(3, 3) + f1(x) + def test_torch_function_subclass_survives_into_aot_autograd(self): # If you have a tensor subclass that relies on dispatch into the same op # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), @@ -2239,6 +2333,7 @@ def fn(x, y): for ref_v, res_v in zip(values_copy, values): self.assertEqual(ref_v.grad, res_v.grad) + @torch._dynamo.config.patch({"capture_scalar_outputs": True}) def test_unbind(self): # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). # This causes a recompile later on when it realizes the batch and last dim diff --git a/test/dynamo_expected_failures/TestTorch.test_cuda_not_built b/test/dynamo_expected_failures/TestTorch.test_cuda_not_built deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 049ce10c9a123f..bf7ad0a4659cc2 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -350,6 +350,7 @@ aten::maximum aten::maximum.out aten::mean aten::mean.dim +aten::mean.dtype_out aten::mean.out aten::minimum aten::minimum.out diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 58d52f34345e6f..a759903d065591 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -445,6 +445,7 @@ aten::_nested_from_padded aten::_nested_from_padded.out aten::_nested_from_padded_and_nested_example aten::_nested_from_padded_and_nested_example.out +aten::_nested_from_padded_tensor aten::_nested_get_jagged_dummy aten::_nested_get_lengths aten::_nested_get_max_seqlen @@ -1301,8 +1302,6 @@ aten::to_padded_tensor.out aten::topk aten::topk.values aten::transpose_ -aten::transpose_copy.int -aten::transpose_copy.int_out aten::triangular_solve aten::triangular_solve.X aten::unbind_copy.int diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 468ee296e36edb..8c416b17da2a7b 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -209,58 +209,103 @@ def forward(self, x): m(*example_inputs) ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) joint_ep = _export_forward_backward(ep) - print(joint_ep) - - """ - ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]"): - # No stacktrace found for following nodes - view: "f32[1, 3]" = torch.ops.aten.view.default(arg3_1, [1, 3]); arg3_1 = None - t: "f32[3, 3]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None - addmm: "f32[1, 3]" = torch.ops.aten.addmm.default(arg1_1, view, t); arg1_1 = t = None - view_1: "f32[3]" = torch.ops.aten.view.default(addmm, [3]); addmm = None - _softmax: "f32[3]" = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None - detach_1: "f32[3]" = torch.ops.aten.detach.default(_softmax) - clone: "f32[3]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None - detach_5: "f32[3]" = torch.ops.aten.detach.default(clone); clone = None - _log_softmax: "f32[3]" = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - detach_12: "f32[3]" = torch.ops.aten.detach.default(_log_softmax) - mul: "f32[3]" = torch.ops.aten.mul.Tensor(_log_softmax, detach_5); _log_softmax = None - sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None - neg: "f32[]" = torch.ops.aten.neg.default(sum_1); sum_1 = None - div: "f32[]" = torch.ops.aten.div.Scalar(neg, 1); neg = None - ones_like: "f32[]" = torch.ops.aten.ones_like.default(div, pin_memory = False, memory_format = torch.preserve_format) - div_1: "f32[]" = torch.ops.aten.div.Scalar(ones_like, 1); ones_like = None - neg_1: "f32[]" = torch.ops.aten.neg.default(div_1); div_1 = None - expand: "f32[3]" = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None - mul_1: "f32[3]" = torch.ops.aten.mul.Tensor(expand, detach_5); expand = detach_5 = None - _log_softmax_backward_data: "f32[3]" = torch.ops.aten._log_softmax_backward_data.default(mul_1, detach_12, 0, torch.float32); mul_1 = detach_12 = None - _softmax_backward_data: "f32[3]" = torch.ops.aten._softmax_backward_data.default(_log_softmax_backward_data, detach_1, 0, torch.float32); _log_softmax_backward_data = detach_1 = None - view_2: "f32[1, 3]" = torch.ops.aten.view.default(_softmax_backward_data, [1, 3]); _softmax_backward_data = None - t_1: "f32[3, 1]" = torch.ops.aten.t.default(view_2) - mm: "f32[3, 3]" = torch.ops.aten.mm.default(t_1, view); t_1 = view = None - t_2: "f32[3, 3]" = torch.ops.aten.t.default(mm); mm = None - sum_2: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None - view_3: "f32[3]" = torch.ops.aten.view.default(sum_2, [3]); sum_2 = None - t_3: "f32[3, 3]" = torch.ops.aten.t.default(t_2); t_2 = None - return (div, t_3, view_3) - - Graph signature: ExportGraphSignature( - input_specs=[ - InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='linear.weight', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='linear.bias', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg2_1'), target='lifted_tensor_0', persistent=None), - InputSpec(kind=, arg=TensorArgument(name='arg3_1'), target=None, persistent=None) - ], - output_specs=[ - OutputSpec(kind=, arg=TensorArgument(name='div'), target=None), - OutputSpec(kind=, arg=TensorArgument(name='t_3'), target='linear.weight'), - OutputSpec(kind=, arg=TensorArgument(name='view_3'), target='linear.bias') - ] + self.assertExpectedInline( + str(joint_ep.graph_module.code).strip(), + """\ +def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): + view = torch.ops.aten.view.default(x, [1, 3]); x = None + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None + view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None + _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None + alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None + clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + alias_2 = torch.ops.aten.alias.default(clone); clone = None + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None + alias_5 = torch.ops.aten.alias.default(_log_softmax) + alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None + neg = torch.ops.aten.neg.default(sum_1); sum_1 = None + div = torch.ops.aten.div.Scalar(neg, 1); neg = None + full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) + div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None + neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None + expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None + exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) + mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None + sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None + alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) + mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None + view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None + permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) + mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None + permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None + view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None + permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + return (div, permute_3, view_3)""", + ) + ep = joint_ep.run_decompositions() + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): + view = torch.ops.aten.view.default(x, [1, 3]); x = None + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, view, permute); p_linear_bias = permute = None + view_1 = torch.ops.aten.view.default(addmm, [3]); addmm = None + _softmax = torch.ops.aten._softmax.default(view_1, 0, False); view_1 = None + alias = torch.ops.aten.alias.default(_softmax) + alias_1 = torch.ops.aten.alias.default(alias); alias = None + clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + alias_2 = torch.ops.aten.alias.default(clone); clone = None + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None + alias_5 = torch.ops.aten.alias.default(_log_softmax) + alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None + neg = torch.ops.aten.neg.default(sum_1); sum_1 = None + div = torch.ops.aten.div.Scalar(neg, 1); neg = None + full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format) + div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None + neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None + expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None + exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) + mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None + sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None + alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) + mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None + view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None + permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) + mm = torch.ops.aten.mm.default(permute_1, view); permute_1 = view = None + permute_2 = torch.ops.aten.permute.default(mm, [1, 0]); mm = None + sum_4 = torch.ops.aten.sum.dim_IntList(view_2, [0], True); view_2 = None + view_3 = torch.ops.aten.view.default(sum_4, [3]); sum_4 = None + permute_3 = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None + return (div, permute_3, view_3)""", ) - Range constraints: {} - """ def test_joint_dynamic(self) -> None: from torch.export import Dim diff --git a/test/export/test_export.py b/test/export/test_export.py index 90422bace4c162..0ab2f49428b29d 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -18,8 +18,13 @@ import torch.nn.functional as F from functorch.experimental.control_flow import cond, map from torch import Tensor -from torch._decomp import get_decompositions +from torch._decomp import ( + _decomp_table_to_post_autograd_aten, + core_aten_decompositions, + get_decompositions, +) from torch._dynamo.test_case import TestCase +from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch._export.utils import ( get_buffer, @@ -28,6 +33,7 @@ is_param, register_dataclass_as_pytree_node, ) +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode from torch.export import Dim, export, unflatten @@ -82,7 +88,7 @@ try: from . import testing except ImportError: - import testing + import testing # @manual=fbcode//caffe2/test:test_export-library # The following import pattern matters as `test_export.export` is patched # in other files (like test_export_nonstrict.py). `torch.export.export` # will invalidate the patch. @@ -291,19 +297,50 @@ def _test_export_same_as_eager(self, f, args, kwargs=None): # ) def _check_dynamic_shapes_specs_and_shapes( - self, model, inputs, specs, passing_shapes, failing_shapes + self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False ): + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + from torch.utils._pytree import tree_map + + def _construct_inputs(shapes): + def _is_tensor_leaf(x): + return isinstance(x, tuple) and all(isinstance(y, int) for y in x) + + return tree_map( + lambda x: torch.randn(*x) if _is_tensor_leaf(x) else x, + shapes, + is_leaf=_is_tensor_leaf, + ) + # exports with a list of equivalent dynamic shapes specs, # then tests for pass/fail on list of shapes for _specs in specs: ep = export(model, inputs, dynamic_shapes=_specs) - for shapes in passing_shapes: - test_inputs = (torch.randn(*shape) for shape in shapes) - ep.module()(*test_inputs) - for shapes in failing_shapes: - test_inputs = (torch.randn(*shape) for shape in shapes) - with self.assertRaises(RuntimeError): + eps = [ep] + if test_serdes: + # test dynamic shapes serialization + # test that behavior remains the same when exporting with ser/des specs: + # serialize + deserialize original specs, and export. + ep_serdes = export( + model, + inputs, + dynamic_shapes=_load_dynamic_shapes( + _dump_dynamic_shapes(_specs, inputs) + ), + ) + eps.append(ep_serdes) + + for ep in eps: + for shapes in passing_shapes: + test_inputs = _construct_inputs(shapes) ep.module()(*test_inputs) + for shapes in failing_shapes: + test_inputs = _construct_inputs(shapes) + with self.assertRaises(RuntimeError): + ep.module()(*test_inputs) def test_basic(self): class Module(torch.nn.Module): @@ -476,17 +513,47 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256)) self.assertEqual(exported_program.module()(*args), m(*args)) - from torch._export import capture_pre_autograd_graph - - gm: torch.fx.GraphModule = capture_pre_autograd_graph( + gm: torch.fx.GraphModule = torch.export.export_for_training( m, args=example_args, dynamic_shapes=dynamic_shapes - ) + ).module() args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256)) self.assertEqual(gm(*args), m(*args)) args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256)) self.assertEqual(gm(*args), m(*args)) + def test_masked_select_dynamic(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mask = x.ge(0.5) + return torch.masked_select(x, mask) + + example_args = (torch.randn(3, 4, 5),) + dim0_x_max, dim1_x_max = 100, 7 + dynamic_shapes = { + "x": { + 0: Dim("dim0_x", max=dim0_x_max), + 1: Dim("dim1_x_max", max=dim1_x_max), + } + } + m = M() + exported_program: torch.export.ExportedProgram = export( + m, args=example_args, dynamic_shapes=dynamic_shapes + ) + + # Test that the expected upper bound is among the range constraints. + expected_upper_bound = dim0_x_max * dim1_x_max * 5 + vr_upper_bounds = [ + vr.upper for vr in exported_program.range_constraints.values() + ] + self.assertTrue(expected_upper_bound in set(vr_upper_bounds)) + # Test that none of the upper bounds are larger. + for vr_upper in vr_upper_bounds: + self.assertTrue(vr_upper <= expected_upper_bound) + def test_setgrad_lifted_tensor(self): class M(torch.nn.Module): def forward(self, x, y): @@ -650,6 +717,61 @@ def forward(self, x, c): foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False ) + def test_unbacked_to_cond(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return (x + 1).sum() + + def false_fn(x): + return (x + 3).sum() + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + + def test_unbacked_to_cond_passthrough(self): + class M(torch.nn.Module): + def forward(self, a): + az = a.nonzero() + + def true_fn(x): + return x + 1 + + def false_fn(x): + return x + 3 + + r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,)) + return r * 2 + + M()(torch.randn(7)) + torch.export.export(M(), (torch.randn(7),)) + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_cond_contains_unbacked_no_escape(self): + class M(torch.nn.Module): + def forward(self, a, b1, b2, c): + def true_fn(x): + return x * b1.item() + + def false_fn(x): + return x * b2.item() + + r = torch.cond(a, true_fn, false_fn, (c,)) + return r * 2 + + args = ( + torch.tensor(True), + torch.tensor([4]), + torch.tensor([4]), + torch.randn(10, requires_grad=True), + ) + torch.export.export(M(), args) + def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: @@ -772,6 +894,44 @@ def forward(self, x): torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21) ) + @testing.expectedFailureTrainingIRToRunDecompNonStrict # TODO(pianpwk): user_output signature + def test_real_tensor_for_max_op(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + x = x[x > 0] + y = y[y > 0] + return max(x.shape[0], y.shape[0]) + + model = Foo() + inputs = (torch.randn(64), torch.randn(64)) + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + ep = export(model, inputs) + + self.assertEqual(ep.module()(*inputs), model(*inputs)) + x = torch.zeros(64) + y = torch.ones(64) + self.assertEqual(ep.module()(x, x), model(x, x)) + self.assertEqual(ep.module()(x, y), model(x, y)) + + def test_export_script_module(self): + class Foo(torch.nn.Module): + def forward(self, rv: torch.Tensor, t: torch.Tensor): + i = t.item() + return rv + i + + foo = Foo() + foo_script = torch.jit.script(foo) + inp = (torch.zeros(3, 4), torch.tensor(7)) + + with self.assertRaisesRegex( + ValueError, "Exporting a ScriptModule is not supported" + ): + export(foo_script, inp) + + from torch._export.converter import TS2EPConverter + + TS2EPConverter(foo_script, inp).convert() + def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self) -> None: @@ -912,14 +1072,17 @@ def forward(self, x): x = self.linear(x) return torch.ops.aten.chunk.default(x, 3, 0) - gm = ( - torch.export.export( - Foo(), - (torch.randn(3, 3),), + ep = torch.export.export(Foo(), (torch.randn(3, 3),)) + if IS_FBCODE: + ep = ep.run_decompositions( + {}, _preserve_ops=(torch.ops.aten.linear.default,) ) - .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,)) - .graph_module - ) + else: + decomp_table = _decomp_table_to_post_autograd_aten() + del decomp_table[torch.ops.aten.linear.default] + ep = ep.run_decompositions(decomp_table) + + gm = ep.graph_module # linear is CompositeImplicitAutograd functional op so we should preserve it # chunk is CompositeImplicitAutograd non-functional op we decompose. self.assertExpectedInline( @@ -934,16 +1097,6 @@ def forward(self, p_linear_weight, p_linear_bias, x): return (getitem, getitem_1, getitem_2)""", ) - # TODO(yidi) - # Expected failure for test cases that calls run_decomposition(). - # The top-level cond node has pre-existing metadata, - # which overrides the metadata for operators in subgraph due to interpreter.run(), - # where cond is a single node in the interpreter.run(). And we preserve metadata - # by copying current node's metadata for all nodes created during interpreting. - @testing.expectedFailurePreDispatchRunDecomp - @testing.expectedFailureRetraceability - @testing.expectedFailureTrainingIRToRunDecomp # T193700910 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_export_cond_preserve_torch_fn_for_subgraphs(self): class MySubModule(torch.nn.Module): def foo(self, x): @@ -1290,6 +1443,7 @@ def forward(self, x, y): "dy - 6 = 6" not in exc.args[0] ) # don't suggest fix for non-root dim + @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") def test_keep_composite_ops_invalid(self): class Foo(torch.nn.Module): def __init__(self) -> None: @@ -1300,32 +1454,29 @@ def forward(self, x): x = self.linear(x) return torch.ops.aten.chunk.default(x, 3, 0) - with self.assertRaisesRegex( - RuntimeError, "aten.chunk.default is a mutating/aliasing op" - ): + def _(*args, **kwargs): + return NotImplemented + + with self.assertWarnsRegex(UserWarning, "The op aten.chunk.default"): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,)) + ).run_decompositions({torch.ops.aten.chunk.default: _}) - with self.assertRaisesRegex( - RuntimeError, "aten.sym_size.default is a metadata query function" - ): + with self.assertWarnsRegex(UserWarning, "The op aten.sym_size.default"): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,)) + ).run_decompositions({torch.ops.aten.sym_size.default: _}) - with self.assertRaisesRegex( - RuntimeError, - "We can't detect aten.native_batch_norm.default as a functional op statically", + with self.assertWarnsRegex( + UserWarning, + "The op aten.native_batch_norm.default", ): _ = torch.export.export( Foo(), (torch.randn(3, 3),), - ).run_decompositions( - {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,) - ) + ).run_decompositions({torch.ops.aten.native_batch_norm.default: _}) def test_keep_composite_ops_linear_convd(self): class MyLinear(torch.nn.Module): @@ -1353,10 +1504,14 @@ def forward(self, x, y): ep = torch.export.export( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - ep_has_linear_convd = ep.run_decompositions( - decomp_table={}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) + if IS_FBCODE: + ep_has_linear_convd = ep.run_decompositions( + {}, + _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, + ) + else: + ep_has_linear_convd = ep.run_decompositions({}) + self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), """\ @@ -1370,13 +1525,19 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ return (add,)""", ) - ep_has_convd = ep.run_decompositions( - decomp_table=None, - _preserve_ops=[ - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ], - ) + if IS_FBCODE: + ep_has_convd = ep.run_decompositions( + _preserve_ops=( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] + + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1392,10 +1553,15 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) + if IS_FBCODE: + ep_has_convd = ep_has_convd.run_decompositions( + _preserve_ops=(torch.ops.aten.conv2d.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] - ep_has_convd = ep_has_convd.run_decompositions( - decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default] - ) + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1435,71 +1601,139 @@ def forward(self, x, y): x_linear = self.linear(x_conv) return x_linear.cos() + y_conv_1d.sum() - ep = torch.export._trace._export_for_training( + ep = torch.export.export_for_training( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - ep_has_linear_convd = ep.run_decompositions( - decomp_table={}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) + + if IS_FBCODE: + ep_has_linear_convd = ep.run_decompositions( + {}, + _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, + ) + else: + ep_has_linear_convd = ep.run_decompositions( + decomp_table={}, + ) + self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None - t = torch.ops.aten.t.default(b_linear_weight); b_linear_weight = None - addmm = torch.ops.aten.addmm.default(b_linear_bias, view, t); b_linear_bias = view = t = None - view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None - cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.default(convolution_1); convolution_1 = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + linear = torch.ops.aten.linear.default(conv2d, b_linear_weight, b_linear_bias); conv2d = b_linear_weight = b_linear_bias = None + cos = torch.ops.aten.cos.default(linear); linear = None + sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) - ep_has_convd = ep.run_decompositions( - decomp_table=None, - _preserve_ops=[ - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ], - ) + if IS_FBCODE: + ep_has_convd = ep.run_decompositions( + _preserve_ops=( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, + ) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] + + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) + self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None - t = torch.ops.aten.t.default(b_linear_weight); b_linear_weight = None - addmm = torch.ops.aten.addmm.default(b_linear_bias, view, t); b_linear_bias = view = t = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None + permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None + addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.default(convolution_1); convolution_1 = None + sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []); conv1d = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) - ep_has_convd = ep_has_convd.run_decompositions( - decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default] - ) + if IS_FBCODE: + ep_has_convd = ep_has_convd.run_decompositions( + _preserve_ops=(torch.ops.aten.conv2d.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) + self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y): - convolution = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None - convolution_1 = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None - view = torch.ops.aten.view.default(convolution, [31680, 98]); convolution = None + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None + view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None cos = torch.ops.aten.cos.default(view_1); view_1 = None - sum_1 = torch.ops.aten.sum.dim_IntList(convolution_1, []); convolution_1 = None + sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []); convolution = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) + @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") + def test_error_when_passing_mutating_primitive_op(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.sin() + + ep = export(Foo(), (torch.ones(3, 3),)) + with self.assertWarnsRegex( + UserWarning, + "The op aten.index_put_.default", + ): + ep.run_decompositions({torch.ops.aten.index_put_.default: None}) + + def test_if_post_autograd_op_preserved(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.sin() + x.sum() + + ep = export(Foo(), (torch.ones(3, 3),)) + if IS_FBCODE: + ep_preserve_sum = ep.run_decompositions( + _preserve_ops=(torch.ops.aten.sum.default,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.sum.default] + ep_preserve_sum = ep.run_decompositions(decomp_table) + + # Even though we are decomposing to core aten which should make + # sum into sum.dim_IntList, we explicitly marked it to not do that. + self.assertExpectedInline( + str(ep_preserve_sum.graph_module.code).strip(), + """\ +def forward(self, x): + sin = torch.ops.aten.sin.default(x) + sum_1 = torch.ops.aten.sum.default(x); x = None + add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None + return (add,)""", + ) + + ep_no_preserve_sum = ep.run_decompositions() + self.assertExpectedInline( + str(ep_no_preserve_sum.graph_module.code).strip(), + """\ +def forward(self, x): + sin = torch.ops.aten.sin.default(x) + sum_1 = torch.ops.aten.sum.dim_IntList(x, []); x = None + add = torch.ops.aten.add.Tensor(sin, sum_1); sin = sum_1 = None + return (add,)""", + ) + def test_set_grad_empty(self): class M(torch.nn.Module): def forward(self, x): @@ -1568,7 +1802,7 @@ def forward(self, x): return self.linear(x) eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, (torch.ones(2, 2),) ) self.assertExpectedInline( @@ -1607,7 +1841,7 @@ def forward(self, x): eager_model_for_export = Foo() eager_model_for_testing = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model_for_export, (torch.ones(4, 4),) ) self.assertExpectedInline( @@ -1652,7 +1886,7 @@ def forward(self, x): eager_model_for_export_training = Foo() eager_model_for_export_inference = Foo() eager_model_for_testing = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model_for_export_training, (torch.ones(4, 4),), dynamic_shapes=({0: Dim("x")},), @@ -1689,7 +1923,7 @@ def forward(self, container): return x + y + self.buffer.sum() eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, ([torch.ones(4, 4), torch.ones(4, 4)],), ) @@ -1715,7 +1949,7 @@ def forward(self, x): return self.linear(x) + self.buffer.sum() eager_model = Foo() - ep_for_training = torch.export._trace._export_for_training( + ep_for_training = torch.export.export_for_training( eager_model, (torch.ones(2, 2),), ) @@ -1725,9 +1959,9 @@ def forward(self, x): """\ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): add = torch.ops.aten.add.Tensor(b_buffer, 5); b_buffer = None - t = torch.ops.aten.t.default(p_linear_weight); p_linear_weight = None - addmm = torch.ops.aten.addmm.default(p_linear_bias, x, t); p_linear_bias = x = t = None - sum_1 = torch.ops.aten.sum.default(add) + permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None + addmm = torch.ops.aten.addmm.default(p_linear_bias, x, permute); p_linear_bias = x = permute = None + sum_1 = torch.ops.aten.sum.dim_IntList(add, []) add_1 = torch.ops.aten.add.Tensor(addmm, sum_1); addmm = sum_1 = None return (add, add_1)""", ) @@ -1771,8 +2005,6 @@ def forward(self, x, y, y1, z): ) def test_static_dim_constraints(self): - from torch.export.dynamic_shapes import DIM - class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1801,7 +2033,7 @@ def forward(self, x, y, z): ((dx, None), (dy, 4), (dz, 3)), ((None, 6), (5, None), (None, None)), ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), - (None, None, (DIM.STATIC, DIM.STATIC)), + (None, None, (Dim.STATIC, Dim.STATIC)), ]: ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) self.assertEqual(foo(*inputs), ep.module()(*inputs)) @@ -1948,9 +2180,7 @@ def forward(self, inp: Inp): self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") def test_mismatched_dynamic_shapes(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC class M(torch.nn.Module): def forward(self, x): @@ -1982,7 +2212,7 @@ def forward(self, x): + re.escape( "specified at `dynamic_shapes[0]['k']['k'][0]` " "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - " where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)" + " where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)" ), ): export(M(), inputs, dynamic_shapes=dynamic_shapes) @@ -2059,7 +2289,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, re.escape( - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export." ), ): @@ -2235,8 +2465,6 @@ def forward(self, t): M = M_v3 export(N(), (t,), strict=strict) - @testing.expectedFailureTrainingIRToRunDecomp - @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? @testing.expectedFailureSerDer # T195866111 def test_suggested_fixes_for_data_dependent_errors_puzzlers(self): # suggested fixes for data-dependent errors only work in non-strict mode @@ -2367,6 +2595,44 @@ def forward(self, xs, y): strict=strict, ) + class Box: + def __init__(self, content): + self.content = content + + from torch.utils._pytree import register_pytree_node + + register_pytree_node( + Box, + lambda box: ([box.content], None), # flatten_fn + lambda contents, _context: Box(*contents), # unflatten_fn + flatten_with_keys_fn=None, # unflatten_fn + serialized_type_name="test_no_suggested_fixes_for_data_dependent_errors.Box", + ) + + class cf_stacklist_udd(torch.nn.Module): + def forward(self, xs, y): + box = Box(y.item()) + # box.content is not a local, so we can't suggest a fix + return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + + with self.assertRaisesRegex( + error_type, + "Could not guard on data-dependent expression u0 < 0", + ): + export( + cf_stacklist_udd(), + ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), + strict=strict, + ) + + def test_tolist(self): + class M(torch.nn.Module): + def forward(self, x): + return x.tolist() + + ep = export(M(), (torch.ones(3, dtype=torch.int),)) + self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3]) + def test_if_functional(self): class Module(torch.nn.Module): def forward(self, x): @@ -2460,35 +2726,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ): em.module()(x) - def test_mark_and_auto_dynamic(self): - # for this use case, mark_dynamic() and AUTO should have same effect. - # check that same symbol gets allocated to both dims without raising constraint violation. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + def test_dont_duck_size_for_auto_dynamic(self): + AUTO, STATIC = Dim.AUTO, Dim.STATIC class Foo(torch.nn.Module): def forward(self, x, y): - torch._check(x.shape[0] == y.shape[0]) - torch._check(x.shape[0] <= 64) - return x + 2, y + 2 - - inputs = (torch.randn(4, 4), torch.randn(4, 4)) - ep_auto = torch.export.export( - Foo(), inputs, dynamic_shapes={"x": (AUTO, None), "y": (AUTO, None)} - ) - torch._dynamo.mark_dynamic(inputs[0], 0) - torch._dynamo.mark_dynamic(inputs[1], 0) - ep_dynamic = torch.export.export(Foo(), inputs) - - # test both programs have same effect - for ep in [ep_auto, ep_dynamic]: - gm = ep.module() - gm(torch.randn(32, 4), torch.randn(32, 4)) - gm(torch.randn(1, 4), torch.randn(1, 4)) - with self.assertRaises(RuntimeError): - gm(torch.randn(33, 4), torch.randn(32, 4)) - gm(torch.randn(128, 4), torch.randn(128, 4)) + # x: [s0, s1], y: [s0 + 1, 4] + assert y.shape[1] == 4 + assert x.shape[0] == y.shape[0] - 1 + return x * 2, y * 2 + + # duck sizing would make all static based on these sample inputs + inputs = (torch.randn(4, 4), torch.randn(5, 4)) + shapes = { + "x": (AUTO, AUTO), + "y": (AUTO, AUTO), + } + ep = export(Foo(), inputs, dynamic_shapes=shapes) + ep.module()(torch.randn(6, 3), torch.randn(7, 4)) @testing.expectedFailureRetraceability # T183144629 def test_map(self): @@ -3425,8 +3680,6 @@ def _patch_config(kwargs): ): _ = export(mod, inp, strict=True) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_static(self): class Module(torch.nn.Module): def forward(self, x): @@ -3441,8 +3694,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_dynamic(self): class Module(torch.nn.Module): def forward(self, x): @@ -3461,8 +3712,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_mutation(self): class Module(torch.nn.Module): def forward(self, x): @@ -3475,8 +3724,6 @@ def forward(self, x): ): export(Module(), (torch.tensor(1, device="cpu"),)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_float_conversion(self): class Module(torch.nn.Module): def forward(self, x): @@ -3491,8 +3738,6 @@ def forward(self, x): for op in ops: self.assertIn(op, (torch.ops.aten._to_copy.default,)) - @testing.expectedFailureTrainingIRToRunDecomp # T193700396 - @testing.expectedFailureTrainingIRToRunDecompNonStrict def test_device_to_mutation_float(self): class Module(torch.nn.Module): def forward(self, x): @@ -3550,6 +3795,18 @@ def forward(self, x): ) ) + def test_use_embedding_twice(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.embed = torch.nn.Embedding(4, 4) + + def forward(self, x): + return self.embed(x) + self.embed.weight[x] + + inputs = (torch.tensor([0, 1, 2, 3]),) + ep = export(Foo(), inputs) + def test_module_with_dict_container_inp_out(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: @@ -4343,12 +4600,9 @@ def forward(self, x): f = Module() if is_non_strict_test(self._testMethodName): error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode - error_msg = r"Could not guard on data-dependent expression" else: error = torch._dynamo.exc.UserError - error_msg = ( - r"Tried to use data-dependent value in the subsequent computation" - ) + error_msg = r"Could not guard on data-dependent expression" with self.assertRaisesRegex(error, error_msg): _ = export(f, (torch.tensor(6),)) @@ -4508,6 +4762,134 @@ def forward(self, x): self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) self.assertEqual(id(state_dict), id(ep.state_dict)) + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_for_inference_e2e(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin(x) + + inp = (torch.randn(5, 10),) + m = M() + + decomp_table = torch.export.core_aten_decompositions() + + def _custom_decomp_for_linear(x, weight, bias): + return x + bias.sum() + + decomp_table[torch.ops.aten.linear.default] = _custom_decomp_for_linear + del decomp_table[torch.ops.aten.sum.default] + ep = torch.export.export_for_inference( + m, inp, decomp_table=decomp_table, dynamic_shapes={"x": {0: Dim("batch")}} + ) + + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + sum_1 = torch.ops.aten.sum.default(p_lin_bias); p_lin_bias = None + add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None + return (add,)""", + ) + + ep_core = ep.run_decompositions() + + self.assertExpectedInline( + str(ep_core.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + sum_1 = torch.ops.aten.sum.dim_IntList(p_lin_bias, []); p_lin_bias = None + add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None + return (add,)""", + ) + + with self.assertRaisesRegex(RuntimeError, "Expected input"): + ep.module()(torch.randn(4, 12)) + + with self.assertRaisesRegex(RuntimeError, "Expected input"): + ep_core.module()(torch.randn(4, 12)) + + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_decomp_torture_case_1(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = torch.nn.Linear(10, 1) + + def forward(self, x): + return self.lin(x) + + inp = (torch.randn(5, 10),) + m = M() + ep = export(m, inp) + + def custom_decomp_callable(x, weight, bias): + return x + bias + + decomp_table = core_aten_decompositions() + decomp_table[torch.ops.aten.linear.default] = custom_decomp_callable + core_aten_ep = ep.run_decompositions(decomp_table) + self.assertExpectedInline( + str(core_aten_ep.graph_module.code).strip(), + """\ +def forward(self, p_lin_weight, p_lin_bias, x): + add = torch.ops.aten.add.Tensor(x, p_lin_bias); x = p_lin_bias = None + return (add,)""", + ) + + @unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode") + def test_export_decomp_torture_case_2(self): + class MyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.randn(20, 98) + self.bias = torch.randn(20) + + def forward(self, x): + return torch.nn.functional.linear(x, self.weight, self.bias) + + class Foo(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(16, 33, 3) + self.conv1d = torch.nn.Conv1d(16, 33, 3) + self.linear = MyLinear() + + def forward(self, x, y): + x_conv = self.conv(x) + y_conv_1d = self.conv1d(y) + x_linear = self.linear(x_conv) + return x_linear.cos() + y_conv_1d.sum() + + ep = export(Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))) + ep_has_linear_convd = ep.run_decompositions(decomp_table={}) + + def _decompose_linear_custom(x, weight, bias): + return torch.matmul(x, weight.T) + 2 * bias + + ep_decompose_linear = ep_has_linear_convd.run_decompositions( + decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom} + ) + + self.assertExpectedInline( + str(ep_decompose_linear.graph_module.code).strip(), + """\ +def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y): + conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None + conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None + permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None + matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None + mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None + add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None + cos = torch.ops.aten.cos.default(add); add = None + sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None + add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None + return (add_1,)""", + ) + def test_export_decomps_dynamic(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -5078,7 +5460,7 @@ def forward( eps, ), ) - ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) + ep.run_decompositions() self.assertEqual( ep.module()( input, weight, bias, running_mean, running_var, training, momentum, eps @@ -5308,7 +5690,7 @@ def forward(self, t, dim, index, src, **kwargs): output = model(t, dim, index, src) ep = torch.export.export(model, args=(t, dim, index, src)) - ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) + ep = ep.run_decompositions() self.assertEqual(ep.module()(t, dim, index, src), output) def test_fqn(self): @@ -5547,7 +5929,6 @@ def forward(self, x): ) ) - # Guard validation upsets the guard def test_cond_with_module_stack_export_with(self): class Bar(torch.nn.Module): def __init__(self) -> None: @@ -6165,7 +6546,6 @@ def forward(self, mul, add, add_1): real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] self.assertEqual(expected_names_and_ops, real_names_and_ops) - @testing.expectedFailureRetraceability def test_placeholder_naming_collisions_hoo_subgraphs(self): # test collisions between user inputs, top-level nodes, and HOO subgraph nodes class Foo(torch.nn.Module): @@ -6221,11 +6601,7 @@ def forward(self, x, mul, mul_1): # (please never do this) class Foo(torch.nn.Module): def forward(self, input, true_graph, body_graph): - def map_body(x, y): - return x + y - - x = map(map_body, input, body_graph[0]) - x = x + true_graph[0] + true_graph[1] + x = input + true_graph[0] + true_graph[1] x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) return x @@ -6237,7 +6613,6 @@ def map_body(x, y): ) ep = export(Foo(), inputs) expected_getattr_names = [ - "body_graph_1", "true_graph_2", "false_graph_0", "true_graph_3", @@ -6847,13 +7222,31 @@ def forward(self, x): if node.op == "call_function": self.assertTrue(False) + @testing.expectedFailureTrainingIRToRunDecomp # T200904004 + @testing.expectedFailureTrainingIRToRunDecompNonStrict + def test_istft_op(self): + class istft_class(torch.nn.Module): + def forward(self, spec): + window = torch.hann_window(1024).type(torch.FloatTensor) + return torch.istft( + spec, + n_fft=1024, + hop_length=512, + window=window, + length=144000, + ) + + model = istft_class() + real_part = torch.randn(1, 513, 282, dtype=torch.float32) + imaginary_part = torch.randn(1, 513, 282, dtype=torch.float32) + spec = torch.complex(real_part, imaginary_part) + export(model, (spec,)) + def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked # with runtime asserts. Check that we specialize to static values when the program says so. - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 1: direct equality between symbols class SimpleEquality(torch.nn.Module): @@ -6880,6 +7273,7 @@ def forward(self, x, y, z): ((4, 4), (4, 4), (4, 3)), ((4, 4), (5, 4), (4, 5)), ], + test_serdes=True, ) # static s1 self._check_dynamic_shapes_specs_and_shapes( @@ -6900,6 +7294,7 @@ def forward(self, x, y, z): ((1, 1), (1, 1), (1, 1)), ((0, 9), (0, 9), (0, 9)), ], + test_serdes=True, ) # fully static self._check_dynamic_shapes_specs_and_shapes( @@ -6915,12 +7310,11 @@ def forward(self, x, y, z): ((1, 3), (1, 3), (1, 3)), ((0, 9), (0, 9), (0, 9)), ], + test_serdes=True, ) def test_automatic_dynamic_shapes_constant_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 2: related by constant: s0 + 4 = s1 class OffBy4(torch.nn.Module): @@ -6944,6 +7338,7 @@ def forward(self, x, y): failing_shapes=[ ((10,), (13,)), ], + test_serdes=True, ) # static s1 should specialize s0 self._check_dynamic_shapes_specs_and_shapes( @@ -6960,12 +7355,11 @@ def forward(self, x, y): ((3,), (7,)), ((2,), (6,)), ], + test_serdes=True, ) def test_automatic_dynamic_shapes_linear_relation(self): - from torch.export.dynamic_shapes import DIM - - AUTO, STATIC = DIM.AUTO, DIM.STATIC + AUTO, STATIC = Dim.AUTO, Dim.STATIC # case 3: linear relation class LinearRel(torch.nn.Module): @@ -6994,6 +7388,7 @@ def forward(self, x, y): ((34,), (8,)), ((22,), (5,)), ], + test_serdes=False, ) # static s1 shouldn't actually specialize s0 (guard: (s0 + 2) // 4 == 5) self._check_dynamic_shapes_specs_and_shapes( @@ -7012,6 +7407,7 @@ def forward(self, x, y): failing_shapes=[ ((33,), (8,)), ], + test_serdes=False, ) # but static s0 will definitely specialize s1 (guard: (21 + 2) // 4 == s1 -> 5 == s1) self._check_dynamic_shapes_specs_and_shapes( @@ -7026,6 +7422,280 @@ def forward(self, x, y): failing_shapes=[ ((22,), (5,)), ], + test_serdes=True, + ) + + def test_dynamic_shapes_serdes_generic(self): + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + + class Foo(torch.nn.Module): + def forward(self, a, b, c, d): + if d == "hello": + x = a[0] + a[1][1:] + b = torch.cat([b, b], dim=0).reshape([-1, 1]) + return x + b, c * 2 + + # test de/serialization on some generic specs + dz = Dim("dz", min=4, max=16) + dx = 2 * dz + dy = dx + 1 + inputs = ( + [ + torch.randn(8, 4), + torch.randn(9, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (dz,), + "c": None, + "d": None, + } + ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes) + self._check_dynamic_shapes_specs_and_shapes( + Foo(), + inputs, + [dynamic_shapes], + [ + ([(16, 4), (17, 4)], (8,), (4, 4), "hello"), + ([(24, 4), (25, 4)], (12,), (4, 4), "hello"), + ], + [ + ([(16, 4), (17, 4)], (8,), (5, 5), "hello"), + ], + test_serdes=True, + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, inputs), + """DynamicShapesSpec(dynamic_shapes=([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), dims={'dz': RootDim(min=4, max=16, derived=['2*dz', '2*dz + 1'])})""", + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True), + """{'dynamic_shapes': ([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), 'dims': {'dz': {'min': 4, 'max': 16, 'derived': ['2*dz', '2*dz + 1']}}}""", + ) + ((dx, _), (dy, _)), (dz,), (_, _), _ = _load_dynamic_shapes( + _dump_dynamic_shapes(dynamic_shapes, inputs) + ) + self.assertEqual(dx.root, dz) + self.assertEqual(dy.root, dz) + + def test_dynamic_shapes_serdes_various(self): + # serialization for dataclass inputs, Dim.AUTO/STATIC, and kwargs + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + ) + + auto, static = Dim.AUTO, Dim.STATIC + + @dataclass + class Input: + a: Tensor + b: Tensor + + register_dataclass_as_pytree_node( + Input, + serialized_type_name="test_dynamic_shapes_serdes_various.Input", + ) + + class Foo(torch.nn.Module): + def forward(self, x, y, z): + return x - torch.randn(4), y.a + y.b + z[1:] + + args = (torch.randn(4, 4),) + kwargs = { + "y": Input(a=torch.randn(8, 8), b=torch.randn(8, 8)), + "z": torch.randn(9, 8), + } + dynamic_shapes = { + "x": (auto, static), + "y": [(auto, auto), (auto, auto)], + "z": (auto, 8), + } + + # dump dynamic_shapes + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, args, kwargs), + """DynamicShapesSpec(dynamic_shapes=(['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), dims={})""", + ) + self.assertExpectedInline( + _dump_dynamic_shapes(dynamic_shapes, args, kwargs, to_dict=True), + """{'dynamic_shapes': (['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), 'dims': {}}""", + ) + + def test_dynamic_shapes_serdes_user_errors(self): + # check error messages for dynamic shapes de/serialization + from torch._export.serde.dynamic_shapes import ( + _dump_dynamic_shapes, + _load_dynamic_shapes, + DynamicShapesSpec, + RootDim, + ) + from torch._export.serde.serialize import _dataclass_to_dict + + # this stuff should be well tested in `test_mismatched_dynamic_shapes` + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]['k']` " + "is a , but `dynamic_shapes[0]['k']` is a " + ), + ): + dynamic_shapes = {"x": {"k": (Dim("dx"), Dim("dy"))}} + _dump_dynamic_shapes(dynamic_shapes, ({"k": [torch.randn(4, 4)]},)) + + # loading with from_dict=True/False + spec = DynamicShapesSpec( + dynamic_shapes=[["dx"]], + dims={"dx": RootDim(min=4, max=16, derived=[])}, + ) + spec_dict = _dataclass_to_dict(spec) + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "With from_dict=True, expected `spec` to be a dict, " + "got " + ), + ): + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape("Expected `spec` to be a DynamicShapesSpec, got "), + ): + _load_dynamic_shapes(spec_dict, from_dict=False) + + self.assertExpectedInline( + _load_dynamic_shapes(spec, from_dict=False), + """[[]]""", + ) + + # check incorrect info in dims + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected dims in `spec['dims']` to map `min` to an int, got dx: None" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dx": { + "min": None, + "max": 4, + "derived": [], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + "got dx which is not in dict_keys(['dy'])" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dy": { + "min": 2, + "max": 4, + "derived": [], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) + + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + re.escape( + "Expected derived expressions to be linear expressions, got dx**2 + 4" + ), + ): + spec = { + "dynamic_shapes": [["dx"]], + "dims": { + "dx": { + "min": 2, + "max": 4, + "derived": ["dx**2 + 4"], + }, + }, + } + _load_dynamic_shapes(spec, from_dict=True) + + @testing.expectedFailureNonStrict + @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? + @testing.expectedFailureSerDer # T195866111 + def test_hints_wrapper(self): + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + x = x + y + + def inner_body_fn(x, y): + x = torch.relu(x) + x = x + y + return x + + def outer_body_fn(x, y): + x = hints_wrapper( + inner_body_fn, (x, y), {}, hints={"inner_body": True} + ) + x = torch.abs(x) + return x + + res = hints_wrapper( + outer_body_fn, (x, y), {}, hints={"outer_body": True} + ) + return res + + x = torch.randn(2, 4) + y = torch.ones(4) + + ep = export(M(), (x, y)) + export_res = ep.module()(x, y) + ref_res = M()(x, y) + self.assertEqual(export_res, ref_res) + self.assertExpectedInline( + normalize_gm(ep.graph_module.print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[2, 4]", y: "f32[4]"): + add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None + + hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None + getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + return (getitem,) + + class hints_wrapper_body_graph_0(torch.nn.Module): + def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): + hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0 + hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None + getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None + abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None + return (abs_1,) + + class hints_wrapper_body_graph_0(torch.nn.Module): + def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): + relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None + add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None + return (add,) +""", ) @@ -7577,13 +8247,6 @@ def forward(self, x): arg = node.args[0] self.assertTrue(arg.op == "placeholder") - def test_tolist_nonstrict_output(self): - class M(torch.nn.Module): - def forward(self, x): - x.tolist() - - ep = torch.export.export(M(), (torch.ones(3),), strict=False) - def test_preserve_non_cia_op(self): class M(torch.nn.Module): def forward(self, x): @@ -7594,10 +8257,15 @@ def forward(self, x): ep.graph_module.code ) - ep = ep.run_decompositions( - decomp_table=get_decompositions([torch.ops.aten.elu.default]), - _preserve_ops=[torch.ops.aten.elu.default], - ) + if IS_FBCODE: + ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.elu.default,)) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.elu.default] + + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run( ep.graph_module.code ) @@ -7619,12 +8287,17 @@ def forward(self, x): "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) - decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec]) - ep = ep.run_decompositions( - decomp_table=decomp_table, - _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec], - ) - assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table + if IS_FBCODE: + ep = ep.run_decompositions( + _preserve_ops=(torch.ops.aten.upsample_bilinear2d.vec,) + ) + else: + decomp_table = core_aten_decompositions() + del decomp_table[torch.ops.aten.upsample_bilinear2d.vec] + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) + FileCheck().check_count( "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) diff --git a/test/export/test_export_nonstrict.py b/test/export/test_export_nonstrict.py index c368eb069a2bd1..99944e4841b72e 100644 --- a/test/export/test_export_nonstrict.py +++ b/test/export/test_export_nonstrict.py @@ -3,8 +3,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export diff --git a/test/export/test_export_training_ir_to_run_decomp.py b/test/export/test_export_training_ir_to_run_decomp.py index 7ff8beaa47094c..b1168f54bb2277 100644 --- a/test/export/test_export_training_ir_to_run_decomp.py +++ b/test/export/test_export_training_ir_to_run_decomp.py @@ -1,32 +1,39 @@ # Owner(s): ["oncall: export"] +import torch +from torch.testing._internal.common_utils import IS_FBCODE + try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library -from torch.export._trace import _export_for_training + import testing # @manual=fbcode//caffe2/test:test_export-library test_classes = {} def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs): - ep = _export_for_training(*args, **kwargs) - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) + ep = torch.export.export_for_training(*args, **kwargs) + if IS_FBCODE: + return ep.run_decompositions( + {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY + ) + return ep.run_decompositions({}) def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs): if "strict" in kwargs: - ep = _export_for_training(*args, **kwargs) + ep = torch.export.export_for_training(*args, **kwargs) else: - ep = _export_for_training(*args, **kwargs, strict=False) - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) + ep = torch.export.export_for_training(*args, **kwargs, strict=False) + + if IS_FBCODE: + return ep.run_decompositions( + {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY + ) + return ep.run_decompositions({}) def make_dynamic_cls(cls, strict): diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 7bcd50a7b40cf1..624b528290c9a8 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -1166,20 +1166,19 @@ def forward(self, x): x = torch.randn([3, 3]) ep = export(mod, (x,)) inplace_ep = unsafe_remove_auto_functionalized_pass(ep) - - nodes = inplace_ep.graph.nodes - getitems = 0 - for node in nodes: - if node.op == "call_function": - self.assertFalse(node.target is auto_functionalized) - if node.target is operator.getitem: - getitems += 1 - self.assertEqual(getitems, 2) # tuple return of len 2 - - out_specs = inplace_ep.graph_signature.output_specs - self.assertEqual(out_specs[0].arg.name, "b_state") # state - self.assertEqual(out_specs[1].arg.name, "getitem") # tuple return 1 - self.assertEqual(out_specs[2].arg.name, "getitem_1") # tuple return 2 + graph_text = str(inplace_ep.graph) + self.assertExpectedInline( + graph_text, + """\ +graph(): + %b_state : [num_users=2] = placeholder[target=b_state] + %x : [num_users=1] = placeholder[target=x] + %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\ +default](args = (%x, %b_state), kwargs = {}) + %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {}) + return (b_state, getitem_3, getitem_4)""", + ) @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_move_to_device_pass(self): diff --git a/test/export/test_retraceability.py b/test/export/test_retraceability.py index d3f914188ccfce..e7f243fd9fb7e4 100644 --- a/test/export/test_retraceability.py +++ b/test/export/test_retraceability.py @@ -3,8 +3,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index 59b83f22c3d23f..a1ced9dd4e5e6c 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -6,8 +6,8 @@ try: from . import test_export, testing except ImportError: - import test_export - import testing + import test_export # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library from torch.export import export, load, save diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 9a2290fcba1164..6f9fa464a86637 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -7,6 +7,7 @@ # Owner(s): ["oncall: export"] import copy import io +import math import tempfile import unittest import zipfile @@ -441,36 +442,6 @@ def forward(self, x): if "aten.sum.dim_IntList" in node.target: self.assertEqual(node.inputs[1].arg.type, "as_ints") - def test_nn_module_stack_serde_with_commas(self) -> None: - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 2) - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.linear(x)) - - # export the model - ep = torch.export.export(M(), (torch.randn(2, 2),)) - # modify the nn_module_stack to contain comma(,) and semicolon(;) - for node in ep.graph.nodes: - if nn_module_stack := node.meta.get("nn_module_stack"): - for k, (p, t) in nn_module_stack.items(): - nn_module_stack[k] = (p + ";semicolon", t + ",comma") - node.meta["nn_module_stack"] = nn_module_stack - # serialize and deserialize the model - buffer = io.BytesIO() - save(ep, buffer) - buffer.seek(0) - loaded_ep = load(buffer) - # check that the output is the same - inp = (torch.randn(2, 2),) - exp_out = ep.module()(*inp) - actual_out = loaded_ep.module()(*inp) - self.assertEqual(exp_out, actual_out) - self.assertEqual(exp_out.requires_grad, actual_out.requires_grad) - @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") @@ -838,6 +809,26 @@ def f(x, y): self.check_graph(M(), inputs) + def test_arg_from(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("compress_weight", torch.ones((10, 10))) + self.register_buffer("compress_bias", torch.ones(10)) + + def forward(self) -> None: + if self.compress_weight is None or self.compress_bias is None: + return + torch.nn.init.kaiming_uniform_(self.compress_weight, a=math.sqrt(5)) + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( + self.compress_weight + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(self.compress_bias, -bound, bound) + + with torch.no_grad(): + self.check_graph(M(), ()) + def test_map(self): from functorch.experimental import control_flow diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index e4ef525e7cdbe0..d179371389c3f7 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -22,6 +22,7 @@ from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG +from torch.export.unflatten import _disable_interpreter from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_utils import ( @@ -896,6 +897,77 @@ def forward(self, x, y): self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1) self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0) + def test_unflatten_eager(self): + class NestedChild(torch.nn.Module): + def forward(self, x): + return x / x + + class Child1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.nested = NestedChild() + self.register_parameter( + "child1param", torch.nn.Parameter(torch.ones(2, 3)) + ) + + def forward(self, x): + x = self.nested(x) + return x + self.child1param + + class Child2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.child2buffer = torch.nn.Buffer(torch.ones(2, 3)) + + def forward(self, x): + return x - self.child2buffer + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.foo = Child1() + self.bar = Child2() + self.register_parameter( + "rootparam", torch.nn.Parameter(torch.ones(2, 3)) + ) + + def forward(self, x): + x = x * self.rootparam + x = self.foo(x) + x = self.bar(x) + return x + + orig_eager = MyModule() + export_module = export(orig_eager, (torch.rand(2, 3),), {}) + with _disable_interpreter(): + unflattened = unflatten(export_module) + + self.assertEqual(unflattened._run_with_interpeter, False) + self.assertEqual(unflattened.foo._run_with_interpeter, False) + + inputs = (torch.rand(2, 3),) + + # Compare the root modules and all submodules + self.compare_outputs(orig_eager, unflattened, inputs) + self.compare_outputs(orig_eager.foo, unflattened.foo, inputs) + self.compare_outputs(orig_eager.bar, unflattened.bar, inputs) + self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs) + + # Check state dicts are equal + orig_state_dict = orig_eager.state_dict() + exported_state_dict = unflattened.state_dict() + for name, value in orig_state_dict.items(): + self.assertTrue(torch.allclose(value, exported_state_dict[name])) + + # Check composability with symbolic trace, as torchrec ddp uses symbolic + # tracer + symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs) + self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs))) + + # torch.compile submodule + unflattened.foo = torch.compile(unflattened.foo, fullgraph=True) + self.compare_outputs(orig_eager, unflattened, inputs) + if __name__ == "__main__": run_tests() diff --git a/test/export/testing.py b/test/export/testing.py index 054e3a611dff86..3647d4c9edd86a 100644 --- a/test/export/testing.py +++ b/test/export/testing.py @@ -226,7 +226,7 @@ def _fn(*args, **kwargs): try: from . import test_export except ImportError: - import test_export + import test_export # @manual=fbcode//caffe2/test:test_export-library with patch(f"{test_export.__name__}.export", mocked_export_fn): return fn(*args, **kwargs) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 002fd7691b6cc4..8c438bc2e4fc72 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -145,6 +145,11 @@ ("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)), ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)), ("aten::_scaled_mm", datetime.date(2024, 12, 31)), + ("aten::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)), + ("aten::wrapped_linear_prepack", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)), + ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)), # BC-breaking change in can_cast signature: 'from' -> 'from_' ("aten::can_cast", datetime.date(2024, 5, 31)), ] diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 55383073c4d3d6..45462674fe8eec 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -725,59 +725,6 @@ def forward(self, primals_1): return (add,)""", ) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case") - @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case") - def test_input_mutation_fsdp_set__into_same_input(self): - import torch.distributed._composable.fsdp._fsdp_param - - def f(a): - b = torch.arange(9, dtype=a.dtype).view(3, 3) - c = torch.arange(9, dtype=a.dtype).view(3, 3) - d = torch.arange(9, dtype=a.dtype).view(3, 3) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, b) - x = a * a - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, c) - y = a * a - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): - torch.ops.fsdp.set_.default(a, c) - z = a * a - return x + y + z - - inp = [torch.ones(3, 3, requires_grad=True)] - fw_graph = self.verify_aot_autograd( - f, inp, test_mutation=True, keep_inp_mutations=True - ) - inp = [torch.ones(3, 3, requires_grad=False)] - self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) - """ - Expected behavior: - (1) When there are multiple set_() calls on the same graph input primal_X, - we want those set_() calls to all show up with primal_X as the first arg in the graph. - (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892), - but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior. - """ - self.assertExpectedInline( - fw_graph.code.strip(), - """\ -def forward(self, primals_1): - arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) - view = torch.ops.aten.view.default(arange, [3, 3]); arange = None - arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) - view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None - set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None - mul = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None - mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1) - set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None - mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1) - add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None - add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None - return (add_1, primals_1)""", - ) - self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp)) - def test_input_mutation_simple_with_none_and_nontensor(self): # Tensor, None, int def f(a, b, c): @@ -6010,6 +5957,18 @@ def forward(self, x): with self.assertRaisesRegex(AssertionError, "Unexpected fake"): aot_module_simplified(MockModule(), (fake_x,), nop) + def test_aot_test_subclasses_with_tensor_factories(self): + from torch.testing._internal.common_subclass import SubclassWithTensorFactory + + inp = SubclassWithTensorFactory(torch.zeros(3, 5)) + + def fn(x): + return 2 * x + + ref_out = fn(inp) + out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp) + self.assertEqual(ref_out, out) + # entries in here don't work and need to be fixed. # Each one of these is a bug (or needs to be investigated) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 2798391c4c7f49..cf639b6abcce4e 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -8,6 +8,7 @@ from functorch.experimental import control_flow from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException from torch._higher_order_ops.associative_scan import associative_scan +from torch._higher_order_ops.scan import scan from torch._higher_order_ops.while_loop import while_loop from torch._subclasses.functional_tensor import ( CppFunctionalizeAPI, @@ -19,10 +20,14 @@ from torch.testing._internal.common_cuda import SM70OrLater from torch.testing._internal.common_quantization import skipIfNoDynamoSupport from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, IS_WINDOWS, parametrize, + requires_cuda, run_tests, + skipIfCrossRef, + skipIfRocm, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO, TestCase, @@ -80,8 +85,8 @@ def _fake_while_loop(cond_fn, body_fn, operands): return operands -def _fake_associative_scan(combine_fn, input, dim, reverse=False): - inp_leaves, spec = pytree.tree_flatten(input) +def _fake_associative_scan(combine_fn, xs, dim, reverse=False): + inp_leaves, spec = pytree.tree_flatten(xs) result_flat = [] num_leaves = len(inp_leaves) op = reversed if reverse else lambda x: x @@ -108,6 +113,118 @@ def _fake_associative_scan(combine_fn, input, dim, reverse=False): return pytree.tree_unflatten(results, spec) +def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): + carry_leaves, carry_spec = pytree.tree_flatten(init) + inp_leaves, inp_spec = pytree.tree_flatten(xs) + if xs is None or len(inp_leaves) == 0: + return init, [] + result_flat = [] + carry = carry_leaves + op = reversed if reverse else lambda x: x + + dummy_carry, dummy_out = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten( + [torch._ops.ops.aten.slice(elem, dim, 0, 1, 1) for elem in inp_leaves], + inp_spec, + ), + ) + dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out) + num_leaves = len(dummy_out_leaves) + + for ind in op(range(inp_leaves[0].size(dim))): + xs = [ + torch._ops.ops.aten.slice(elem, dim, ind, ind + 1, 1) for elem in inp_leaves + ] + + carry, y = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(xs, inp_spec), + ) + carry, _ = pytree.tree_flatten(carry) + y, _ = pytree.tree_flatten(y) + result_flat.append(y) + + results = [ + torch.concatenate([e[leave_ind] for e in op(result_flat)], dim) + for leave_ind in range(num_leaves) + ] + return ( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(results, dummy_out_spec), + ) + + +def compile_mode_helper(fct, compile_mode): + if compile_mode == "compile": + return torch.compile(fct, fullgraph=True, dynamic=False) + elif compile_mode == "compile_dynamic_shape": + return torch.compile(fct, fullgraph=True, dynamic=True) + elif compile_mode == "eager": + return torch.compile(fct, fullgraph=True, backend="eager") + else: + return fct + + +def get_scan_combine_fn(name, associative=True): + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + def adds(x: torch.Tensor, y: torch.Tensor): + return x + x, y + y + + def mul(x: torch.Tensor, y: torch.Tensor): + return x * y + + def div(x: torch.Tensor, y: torch.Tensor): + return x / y + + def s5_operator(x, y): + A_i, Bu_i = x + A_j, Bu_j = y + return A_j * A_i, A_j * Bu_i + Bu_j + + def tuple_fct(x, y): + return (x[0] + y[0], x[1] * y[1]) + + def complex_pointwise(x, y): + return { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + } + + def non_pointwise(x: torch.Tensor, y: torch.Tensor): + W = torch.diag(torch.ones(2, device=x.device)) + return x @ W + y @ W + + if name == "add": + fct = add + elif name == "adds": + fct = adds + elif name == "mul": + fct = mul + elif name == "div": + fct = div + elif name == "s5_operator": + fct = s5_operator + elif name == "tuple_fct": + fct = tuple_fct + elif name == "complex_pointwise": + fct = complex_pointwise + elif name == "non_pointwise": + fct = non_pointwise + else: + raise ValueError("Combine_fn name unknown!") + + if not associative: + return lambda x, y: (fct(x, y), fct(x, y)) + else: + return fct + + def _while_loop_tests(): def simple(x): def cond_fn(x): @@ -1204,21 +1321,35 @@ def fwbw(map_op, f, x, y): fake_outs = fwbw(_fake_map, f, x, y) self.assertEqual(true_outs, fake_outs) + # TODO: provide an implementation for all compile modes and re-enable all test @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_simple(self, reverse, device): - def add(x: torch.Tensor, y: torch.Tensor): - return x + y + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_compile( + self, combine_mode, reverse, compile_mode, device + ): + x = torch.randn(3, 10, 2, device=device) - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y + scan_fct = compile_mode_helper(associative_scan, compile_mode) - x = torch.randn(3, 10, 2, device=device) - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan(op, x, 0, reverse=reverse) - result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) + for op, op_pt in [ + (get_scan_combine_fn("add", True), torch.cumsum), + (get_scan_combine_fn("mul", True), torch.cumprod), + ]: + result = scan_fct(op, x, 0, reverse=reverse, combine_mode=combine_mode) + result_exp = _fake_associative_scan(op, xs=x, dim=0, reverse=reverse) self.assertEqual(result, result_exp) if not reverse: result_exp_PT = op_pt(x, 0) @@ -1226,8 +1357,16 @@ def mul(x: torch.Tensor, y: torch.Tensor): # Jax Examples x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan(add, x, 0, reverse=reverse) - cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) + cumsum1 = scan_fct( + get_scan_combine_fn("add", True), + x, + 0, + reverse=reverse, + combine_mode=combine_mode, + ) + cumsum_exp = _fake_associative_scan( + get_scan_combine_fn("add", True), x, 0, reverse=reverse + ) if not reverse: self.assertEqual( cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) @@ -1238,18 +1377,191 @@ def mul(x: torch.Tensor, y: torch.Tensor): ) self.assertEqual(cumsum1, cumsum_exp) - @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_dim(self, reverse, device): - import random + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_compile(self, reverse, compile_mode, device): + def add2(x: torch.Tensor, y: torch.Tensor): + return x * y, x + y - def add(x: torch.Tensor, y: torch.Tensor): - return x + y + x = torch.randn(3, 10, 2, device=device) - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y + scan_fct = compile_mode_helper(scan, compile_mode) + + for op, op_pt, init in [ + ( + get_scan_combine_fn("add", False), + torch.cumsum, + torch.zeros(1, 10, 2, device=device), + ), + ( + get_scan_combine_fn("mul", False), + torch.cumprod, + torch.ones(1, 10, 2, device=device), + ), + ]: + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, 0) + self.assertEqual(result[1], result_exp_PT) + + # Jax Examples + x = torch.arange(0, 4, device=device, dtype=torch.int64) + init = torch.zeros(1, device=device, dtype=torch.int64) + cumsum1 = scan_fct( + get_scan_combine_fn("add", False), + init, + x, + dim=0, + reverse=reverse, + ) + cumsum_exp = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + if not reverse: + self.assertEqual( + cumsum1[1], torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) + ) + self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) + else: + self.assertEqual( + cumsum1[1], torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) + ) + self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) + self.assertEqual(cumsum1, cumsum_exp) + + # Different carry computation as output computation + x = torch.arange(1, 5, device=device, dtype=torch.int64) + init = torch.ones(1, device=device, dtype=torch.int64) + result = scan_fct(add2, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(add2, init=init, xs=x, dim=0, reverse=reverse) + if not reverse: + self.assertEqual( + result[1], torch.tensor([2.0, 3.0, 5.0, 10.0], dtype=torch.int64) + ) + self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) + else: + self.assertEqual( + result[1], torch.tensor([25.0, 14.0, 7.0, 5.0], dtype=torch.int64) + ) + self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) + self.assertEqual(result, result_exp) + + # Non associative operation + x = torch.arange(0, 5, device=device, dtype=torch.float32) + init = torch.ones(1, device=device, dtype=torch.float32) + result = scan_fct( + get_scan_combine_fn("div", False), + init, + x, + dim=0, + reverse=reverse, + ) + result_exp = _fake_scan( + get_scan_combine_fn("div", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, result_exp) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + @parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.int32, + torch.int64, + torch.complex64, + ], + ) + def test_scan_dtype(self, reverse, compile_mode, device, dtype): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Check all outputs and carries on the correct device and with torch.float32 + x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=dtype), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.device.type for r in res] for res in result], + [[device.type for _ in res] for res in result], + ) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [[dtype for _ in res] for res in result], + ) + + # Check all outputs and carries on the correct device and + # carry.dtype torch.float32 and output.dtype torch.float16 + x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=torch.float32), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [ + [torch.float32 for _ in range(len(result[0]))], + [dtype for _ in range(len(result[1]))], + ], + ) + + # Check all outputs and carries on the correct device and + # carry.dtype torch.int64 and output.dtype torch.float32 + x = torch.randn(3, 10, 2, device=device) + op, init = ( + get_scan_combine_fn("adds"), + torch.zeros(1, 10, 2, device=device, dtype=dtype), + ) + result = scan_fct(op, init, x, dim=0, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual( + [[r.dtype for r in res] for res in result], + [ + [dtype for _ in range(len(result[0]))], + [torch.float32 for _ in range(len(result[1]))], + ], + ) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_dim(self, combine_mode, reverse, device): + import random num_dims = [random.randint(2, 5) for _ in range(10)] for num_dim in num_dims: @@ -1257,8 +1569,13 @@ def mul(x: torch.Tensor, y: torch.Tensor): rnd_scan_dim = random.randint(0, num_dim - 1) x = torch.randn(*shapes, device=device) - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan(op, x, rnd_scan_dim, reverse=reverse) + for op, op_pt in [ + (get_scan_combine_fn("add", True), torch.cumsum), + (get_scan_combine_fn("mul", True), torch.cumprod), + ]: + result = associative_scan( + op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode + ) result_exp = _fake_associative_scan( op, x, rnd_scan_dim, reverse=reverse ) @@ -1267,52 +1584,1363 @@ def mul(x: torch.Tensor, y: torch.Tensor): result_exp_PT = op_pt(x, rnd_scan_dim) self.assertEqual(result, result_exp_PT) + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_dim(self, reverse, device): + import random + + num_dims = [random.randint(2, 5) for _ in range(10)] + for num_dim in num_dims: + shapes = [random.randint(1, 10) for _ in range(num_dim)] + rnd_scan_dim = random.randint(0, num_dim - 1) + x = torch.randn(*shapes, device=device) + init_shapes = shapes + init_shapes[rnd_scan_dim] = 1 + + for op, op_pt, init in [ + ( + get_scan_combine_fn("add", False), + torch.cumsum, + torch.zeros(*init_shapes, device=device), + ), + ( + get_scan_combine_fn("mul", False), + torch.cumprod, + torch.ones(*init_shapes, device=device), + ), + ]: + result = scan(op, init, x, dim=rnd_scan_dim, reverse=reverse) + result_exp = _fake_scan( + op, init=init, xs=x, dim=rnd_scan_dim, reverse=reverse + ) + self.assertEqual(result, result_exp) + if not reverse: + result_exp_PT = op_pt(x, rnd_scan_dim) + self.assertEqual(result[1], result_exp_PT) + + @skipIfRocm(msg="Unsupported on ROCM yet") @unittest.skipIf(not SM70OrLater, "triton") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_binary_operator(self, combine_mode, reverse, device): + state_dim = 20 + timesteps = 10 + projected_inputs = torch.randn( + timesteps, state_dim, requires_grad=True, device=device + ) + A = torch.randn(state_dim, requires_grad=True, device=device) + elements = (A.repeat((timesteps, 1)), projected_inputs) + + result1 = associative_scan( + get_scan_combine_fn("s5_operator", True), + elements, + 0, + combine_mode=combine_mode, + reverse=reverse, + ) + expected_result = _fake_associative_scan( + get_scan_combine_fn("s5_operator", True), elements, 0, reverse=reverse + ) + self.assertEqual( + result1, + expected_result, + ) + self.assertEqual([r.device.type for r in result1], [device.type] * len(result1)) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_binary_operator(self, reverse, device): + state_dim = 20 + timesteps = 10 + projected_inputs = torch.randn( + timesteps, state_dim, requires_grad=True, device=device + ) + A = torch.randn(state_dim, requires_grad=True, device=device) + elements = (A.repeat((timesteps, 1)), projected_inputs) + init = tuple( + [torch.ones_like(torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1))] + + [ + torch.zeros_like( + torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1) + ) + ] + ) + + result = scan( + get_scan_combine_fn("s5_operator", False), + init, + elements, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("s5_operator", False), + init=init, + xs=elements, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, expected_result) + + @skipIfRocm(msg="Unsupported on ROCM yet") + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_tuple(self, combine_mode, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + inp = (x, y) + + result1 = associative_scan( + get_scan_combine_fn("tuple_fct", True), + inp, + 0, + reverse=reverse, + combine_mode=combine_mode, + ) + expected_result = _fake_associative_scan( + get_scan_combine_fn("tuple_fct", True), inp, 0, reverse=reverse + ) + self.assertEqual(result1, expected_result) + + @skipIfRocm(msg="Unsupported on ROCM yet") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_tuple(self, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + inp = (x, y) + init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp) + + result_same = scan( + get_scan_combine_fn("tuple_fct", False), + init, + inp, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("tuple_fct", False), + init=init, + xs=inp, + dim=0, + reverse=reverse, + ) + self.assertEqual(result_same, expected_result) + + def fct_different_output_tuple(x, y): + return ((x[0] + y[0], x[1] * y[1]), (x[1] * y[1])) + + inp = (x, y) + init = tuple(torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp) + + result_diff = scan( + fct_different_output_tuple, init, inp, dim=0, reverse=reverse + ) + expected_result = _fake_scan( + fct_different_output_tuple, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result_diff, expected_result) + self.assertEqual(result_diff[1], result_same[1][1]) + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_associative_scan_wrong_pytree(self, device): + def fct_wrong_pytree(x, y): + return { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + } + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + + with self.assertRaisesRegex( + # Should be: RuntimeError, + # r"The number of leaves of the pytree of the output of the operator + # needs to match the lenght of the pytree of the input", + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic") + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_complex_pytree(self, combine_mode, reverse, device): + def fct_pointwise(x, y): + return { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + } + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + + result = associative_scan( + get_scan_combine_fn("complex_pointwise", True), + inp, + 0, + combine_mode=combine_mode, + reverse=reverse, + ) + expected_result = _fake_associative_scan( + get_scan_combine_fn("complex_pointwise", True), inp, 0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + @requires_cuda + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_wrong_pytree(self, device): + # Init and input have same pytree + def fct_wrong_pytree(x, y): + return ( + { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + }, + { + "i": x["i"] * y["j"][0][0], + "k": 0.0, + "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]), + }, + ) + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) + + with self.assertRaisesRegex( + # Should be: RuntimeError, + # r"The number of leaves of the pytree of the new carry produced by + # the operator needs to match the length of the pytree of the init", + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = scan(fct_wrong_pytree, init, inp, dim=0) + + @requires_cuda @parametrize("reverse", [False, True]) - @parametrize("compile_mode", ["compile", "compile_dynamic_shape"]) - @parametrize("device", [torch.device("cuda")]) - def test_pointwise_associative_scan_reverse_compile( - self, reverse, compile_mode, device + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_complex_pytree(self, reverse, device): + # Init and input have same pytree + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) + + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("complex_pointwise", False), + init=init, + xs=inp, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_downstream_scan_matmul( + self, combine_mode, compile_mode, reverse, device ): - def add(x: torch.Tensor, y: torch.Tensor): - return x + y + # Chain with matmul + def chain_fct(inp): + W = torch.ones(2, 5, device=device) + o = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + reverse=reverse, + combine_mode=combine_mode, + ) + return o @ W - def mul(x: torch.Tensor, y: torch.Tensor): - return x * y + fct_cmp = compile_mode_helper(chain_fct, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ) @ torch.ones(2, 5, device=device) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_downstream_scan_scan( + self, combine_mode, compile_mode, reverse, device + ): + # Chain with scan + def chain_fct_same_dim(inp): + o1 = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + o2 = associative_scan( + get_scan_combine_fn("add", True), + o1, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_same_dim, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), + _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ), + 1, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("combine_mode", ["pointwise", "generic"]) + @parametrize("compile_mode", ["none", "compile", "compile_dynamic_shape"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of combine_mode=pointwise and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: ( + params["combine_mode"] == "pointwise" + and (params["device"] == torch.device("cpu") or torch.version.hip) + ), + ) + def test_associative_scan_downstream_scan_scan_different_dim( + self, combine_mode, compile_mode, reverse, device + ): + # Chain with scan on different dim + def chain_fct_different_dim(inp): + o1 = associative_scan( + get_scan_combine_fn("add", True), + inp, + 1, + combine_mode=combine_mode, + reverse=reverse, + ) + o2 = associative_scan( + get_scan_combine_fn("add", True), + o1, + 0, + combine_mode=combine_mode, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode) + + inp = torch.randn(3, 10, 2, device=device) + expected_result = _fake_associative_scan( + get_scan_combine_fn("add", True), + _fake_associative_scan( + get_scan_combine_fn("add", True), inp, 1, reverse=reverse + ), + 0, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + for ind in range(2): + # Chain with matmul + def chain_fct(inp): + W = torch.ones(2, 5, device=device) + o = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + return o[ind] @ W + + fct_cmp = compile_mode_helper(chain_fct, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[ind] @ torch.ones(2, 5, device=device) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_scan(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + # Chain with scan + def chain_fct_same_dim(inp): + o1 = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + o2 = scan( + get_scan_combine_fn("add", False), + init, + o1[1], + dim=1, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_same_dim, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=_fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[1], + dim=1, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + + # TODO: provide an implementation for all compile modes and re-enable all test + @requires_cuda + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_downstream_scan_scan_dim(self, compile_mode, reverse, device): + inp = torch.randn(3, 10, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + + # Chain with scan on different dim + init2 = torch.randn(1, 10, 2, device=device) + + def chain_fct_different_dim(inp): + o1 = scan( + get_scan_combine_fn("add", False), + init, + inp, + dim=1, + reverse=reverse, + ) + o2 = scan( + get_scan_combine_fn("add", False), + init2, + o1[1], + dim=0, + reverse=reverse, + ) + return o2 + + fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode) + + expected_result = _fake_scan( + get_scan_combine_fn("add", False), + init=init2, + xs=_fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[1], + dim=0, + reverse=reverse, + ) + result1 = fct_cmp(inp) + self.assertEqual(result1, expected_result) + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of associative_scan and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: (params["device"] == torch.device("cpu")), + ) + def test_associative_scan_non_pointwise(self, reverse, device): x = torch.randn(3, 10, 2, device=device) - torch.compiler.reset() - if compile_mode == "compile": - associative_scan_fct = torch.compile( - associative_scan, fullgraph=True, dynamic=False + # Expected to fail, as the pointwise combine_mode does not allow non-pointwise operations + with self.assertRaisesRegex( + Exception, + "For combine_mode='pointwise', the combine_fn needs to be pointwise", + ): + out = associative_scan( + get_scan_combine_fn("non_pointwise", True), + x, + 0, + reverse=reverse, + combine_mode="pointwise", ) - else: - associative_scan_fct = torch.compile( - associative_scan, fullgraph=True, dynamic=True + + @unittest.skipIf(not SM70OrLater, "triton") + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + # Skipping the combination of associative_scan and device=cpu + # as the current implementation of pointwise does only support CUDA device + @decorateIf( + unittest.skip, + lambda params: (params["device"] == torch.device("cpu")), + ) + def test_associative_scan_non_pointwise_generic(self, reverse, device): + x = torch.randn(3, 10, 2, device=device) + result_expected = _fake_associative_scan( + get_scan_combine_fn("non_pointwise", True), x, 0, reverse=reverse + ) + result1 = associative_scan( + get_scan_combine_fn("non_pointwise", True), + x, + 0, + reverse=reverse, + combine_mode="generic", + ) + self.assertEqual(result1, result_expected) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_non_pointwise(self, reverse, device): + x = torch.randn(3, 10, 2, device=device) + init = torch.randn(1, 10, 2, device=device) + result_expected = _fake_scan( + get_scan_combine_fn("non_pointwise", False), + init=init, + xs=x, + dim=0, + reverse=reverse, + ) + + out = scan( + get_scan_combine_fn("non_pointwise", False), + init, + x, + dim=0, + reverse=reverse, + ) + self.assertEqual(out, result_expected) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_compile_cnt(self, reverse, device): + dim = 1 + + from torch._dynamo.testing import CompileCounter + + # Tests rely on automatic_dynamic = True + with torch._dynamo.config.patch(automatic_dynamic_shapes=True): + cnt = CompileCounter() + x = torch.randn(3, 2, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # First compilation step + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, ) + self.assertEqual(cnt.frame_count, 1) - for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]: - result = associative_scan_fct(op, x, 0, reverse=reverse) - result_exp = _fake_associative_scan(op, x, 0, reverse=reverse) - self.assertEqual(result, result_exp) - if not reverse: - result_exp_PT = op_pt(x, 0) - self.assertEqual(result, result_exp_PT) + x = torch.randn(3, 20, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # Recompilation due to first different size + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 2) - # Jax Examples - x = torch.arange(0, 4, device=device) - cumsum1 = associative_scan_fct(add, x, 0, reverse=reverse) - cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse) + x = torch.randn(3, 40, 5, device=device) + init = torch.randn(3, 1, 5, device=device) + # No recompilation, because of dynamic shape + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 2) + + x = torch.randn(3, 40, 5, device=device) + init = torch.randn(3, 40, 1, device=device) + # Recompilation because of dim change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 3) + + x = torch.randn(3, 40, 20, device=device) + init = torch.randn(3, 40, 1, device=device) + # Recompilation due to first different size on new dim + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 4) + + x = torch.randn(3, 40, 40, device=device) + init = torch.randn(3, 40, 1, device=device) + # No recompilation, because of dynamic shape on new dim + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=2, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 4) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # Recompilation because of dim change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 5) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # Recompilation because of reverse change + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=not reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + x = torch.randn(3, 60, 40, device=device) + init = torch.randn(3, 1, 40, device=device) + # No recompilation, as nothing changed + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=not reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + x = torch.randn(3, 120, 80, device=device) + init = torch.randn(3, 1, 80, device=device) + # No recompilation, final test + torch.compile(scan, backend=cnt)( + get_scan_combine_fn("add", False), + init, + x, + dim=1, + reverse=reverse, + ) + self.assertEqual(cnt.frame_count, 6) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_scanned_0(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Scan dimension is 0 + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + with self.assertRaisesRegex( + # Should be: RuntimeError, "Input leaves must have a scan dimension > 0" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_non_tensor(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init is a float and not a tensor + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = 1.0 + with self.assertRaisesRegex( + # Should be: RuntimeError, "Init leaves must be a Tensor" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), init, inp, dim=dim, reverse=reverse + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_shape(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init wrong shape (Other dim different) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + init = torch.tile(init, (1, 2, 1)) + with self.assertRaisesRegex( + # Should be: RuntimeError, "The size of tensor a.*" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_pytree(self, reverse, compile_mode, device): + def add_one_carry(x: torch.Tensor, y: torch.Tensor): + return x[0], x + + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + + # Init wrong pytree + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + init = ( + torch._ops.ops.aten.slice(x, dim, 0, 1, 1), + torch._ops.ops.aten.slice(x, dim, 0, 1, 1), + ) + + with self.assertRaisesRegex( + # Should be: RuntimeError: The number of leaves of the pytree of the new carry produced + # by the operator needs to match the length of the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct(add_one_carry, init, inp, dim=dim, reverse=reverse) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("compile_mode", ["none", "eager"]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init(self, reverse, compile_mode, device): + scan_fct = compile_mode_helper(scan, compile_mode) + + # Only init and no input + x = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 1, 2, device=device) + dim = 1 + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + + # Only init given + init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) + result = scan_fct(op, init, [], dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=[], dim=dim, reverse=reverse) + result_init = scan_fct(op, init, [], dim=dim, reverse=reverse) + self.assertEqual(result, result_exp) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0], init) + + x = torch.randn(3, 5, 2, device=device) + init = torch.randn(3, 5, 2, device=device) + dim = 0 + + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + + # Init tensor scalar + init = torch.ones(1, device=device) + + def add_scalar_carry(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x + y + + result_init = scan_fct(add_scalar_carry, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan( + add_scalar_carry, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0], torch.tensor([3.0], device=device)) + + # Init tensor entirely different shape than inp + init = torch.randn(7, 8, device=device) + + def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x[: y.shape[1], : y.shape[2]] + y + + result_init = scan_fct(add_scalar_carry2, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan( + add_scalar_carry2, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + + # Init with two timestep on dim axis. Should work as y has always 1 on dim axis and + # hence automatic broadcasting should work + # I.e., the input shape is 2x5x2, but the carry at each iteration is 2x5x2, + # thus the output of each iteration is 2x5x2, which results in the total output + # to be 4x5x2 + init = torch._ops.ops.aten.slice(x, dim, 0, 2, 1) + result_init = scan_fct(op, init, inp, dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=inp, dim=dim, reverse=reverse) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0].shape, torch.Size([2, 5, 2])) + + init = torch.tile(init, (1, 2, 1)) + + def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): + return x + 1.0, x[:, :1, :] + y + + result_init = scan_fct( + add_scalar_carry_sliced_out, init, inp, dim=dim, reverse=reverse + ) + result_exp = _fake_scan( + add_scalar_carry_sliced_out, init=init, xs=inp, dim=dim, reverse=reverse + ) + self.assertEqual(result_init, result_exp) + self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2])) + self.assertEqual(result_init[1].shape, torch.Size([4, 5, 2])) + + # Correct case + op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) + x = torch.randn(3, 2, 2, device=device) + dim = 1 + + if reverse: + init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, -1, None, 1)) + inp = torch._ops.ops.aten.slice(x, dim, 0, -1, 1) + else: + init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, 0, 1, 1)) + inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) + + result = scan_fct(op, init, x, dim=dim, reverse=reverse) + result_exp = _fake_scan(op, init=init, xs=x, dim=dim, reverse=reverse) + + self.assertEqual(result, result_exp) if not reverse: - self.assertEqual( - cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) + result_exp_PT = op_pt(x, dim) + self.assertEqual(result[1], result_exp_PT) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_carry_wrong_pytree(self, reverse, device): + def fct_pointwise_carry_wrong_pytree(x, y): + return ( + ( + x["i"], + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ) + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + inp = {"i": x, "j": ([y], [{"o": z}])} + inp_flat, inp_spec = pytree.tree_flatten(inp) + init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] + init = pytree.tree_unflatten(init_flat, inp_spec) + + # Wrong pytree of the carry produced by the operation + with self.assertRaisesRegex( + # Should be: RuntimeError: The number of leaves of the pytree of the new carry + # produced by the operator needs to match the length of the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result = scan( + fct_pointwise_carry_wrong_pytree, + init, + inp, + dim=0, + reverse=reverse, + ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_wrong_pytree_complex(self, reverse, device): + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + + # Wrong pytree fed to the function + init = { + "i": torch._ops.ops.aten.slice(x, 0, 0, 1, 1), + "j": ( + {"a": torch._ops.ops.aten.slice(x, 0, 0, 1, 1)}, + [torch._ops.ops.aten.slice(y, 0, 0, 1, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, 0, 1, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, 0, None, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, 0, None, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, 0, None, 1)}], + ), + } + with self.assertRaisesRegex( + Exception, + ".*", + ): + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, ) + + @requires_cuda + @parametrize("reverse", [False, True]) + @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) + def test_scan_init_pytree_complex(self, reverse, device): + def fct_pointwise_different_output(x, y): + return ( + { + "i": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ( + y["i"], + { + "o": x["i"] * y["i"], + "j": ( + [x["j"][0][0] * y["j"][0][0]], + [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + ) + + def fct_pointwise_different_carry(x, y): + return ( + { + "i": x["i"] * y["i"], + "j": ( + x["i"], + [x["j"][1][0] * y["j"][0][0]], + [{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ( + y["i"], + { + "o": x["i"] * y["i"] + x["j"][0][0], + "j": ( + [x["j"][1][0] * y["j"][0][0]], + [{"o": x["j"][2][0]["o"] + y["j"][1][0]["o"]}], + ), + }, + ), + ) + + x = torch.randn(3, 2, 2, device=device) + y = torch.randn(3, 2, 2, device=device) + z = torch.randn(3, 2, 2, device=device) + + if reverse: + init_start, init_end = -1, None + inp_start, inp_end = 0, -1 else: - self.assertEqual( - cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) + init_start, init_end = 0, 1 + inp_start, inp_end = 1, None + + # Regular case + init = { + "i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], + ), + } + result = scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + expected_result = _fake_scan( + get_scan_combine_fn("complex_pointwise", False), + init, + inp, + dim=0, + reverse=reverse, + ) + self.assertEqual(result, expected_result) + + # Pytree of output is different + result = scan(fct_pointwise_different_output, init, inp, dim=0, reverse=reverse) + expected_result = _fake_scan( + fct_pointwise_different_output, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + # Pytree of carry is different + init = { + "i": torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + "j": ( + torch._ops.ops.aten.slice(x, 0, init_start, init_end, 1), + [torch._ops.ops.aten.slice(y, 0, init_start, init_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, init_start, init_end, 1)}], + ), + } + inp = { + "i": torch._ops.ops.aten.slice(x, 0, inp_start, inp_end, 1), + "j": ( + [torch._ops.ops.aten.slice(y, 0, inp_start, inp_end, 1)], + [{"o": torch._ops.ops.aten.slice(z, 0, inp_start, inp_end, 1)}], + ), + } + result = scan(fct_pointwise_different_carry, init, inp, dim=0, reverse=reverse) + expected_result = _fake_scan( + fct_pointwise_different_carry, init=init, xs=inp, dim=0, reverse=reverse + ) + self.assertEqual(result, expected_result) + + def test_scan_RNN(self): + dim = 1 + device = torch.device("cpu") + + rnn = torch.nn.RNN( + input_size=5, + hidden_size=7, + ) + rnn = rnn.to(device=device) + x = torch.randn(1, 2, 5, device=device) + h = torch.randn(1, 2, 7, device=device) + + new_state_dict = { + "weight_ih_l0": torch.ones_like(rnn.weight_ih_l0), + "bias_ih_l0": torch.ones_like(rnn.bias_ih_l0), + "weight_hh_l0": torch.ones_like(rnn.weight_hh_l0), + "bias_hh_l0": torch.ones_like(rnn.bias_hh_l0), + } + rnn.load_state_dict(new_state_dict) + + def RNN(x: torch.Tensor, y: torch.Tensor): + W_ih = torch.ones((5, 7), device=device) + b_ih = torch.ones((7), device=device) + W_hh = torch.ones((7, 7), device=device) + b_hh = torch.ones((7), device=device) + c_new = y @ W_ih + b_ih + h_new = torch.tanh(c_new + x @ W_hh + b_hh) + return h_new, h_new + + expected_result = rnn( + torch.permute(x, (1, 0, 2)), torch.unsqueeze(h[:, 0, :], 0) + ) + expected_result_out = torch.permute(expected_result[0], (1, 0, 2)) + expected_result_state = torch.permute(expected_result[1], (1, 0, 2)) + result = scan(RNN, h[:, 0:1, :], x, dim=dim) + self.assertEqual(result[0], expected_result_state) + self.assertEqual(result[1], expected_result_out) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_no_carry(self): + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong number of returns from function + with self.assertRaisesRegex( + # Should be: RuntimeError: The pytree of the new carry produced + # by the operator needs to match the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + gm = make_fx(f, tracing_mode="symbolic")( + get_scan_combine_fn("add", True), init, x ) - self.assertEqual(cumsum1, cumsum_exp) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_wrong_carry(self): + def add_wrong_carry(x: torch.Tensor, y: torch.Tensor): + return (x + y)[0, :], x + y + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong carry shape + with self.assertRaisesRegex( + # Should be: RuntimeError: The pytree of the new carry produced by + # the operator needs to match the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + gm = make_fx(f, tracing_mode="symbolic")(add_wrong_carry, init, x) + + @skipIfNoDynamoSupport + def test_scan_simple_graph_wrong_dtype(self): + def add_wrong_dtype(x: torch.Tensor, y: torch.Tensor): + return torch.ones_like(x + y, dtype=torch.int64), x + y + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Wrong dtype + with self.assertRaisesRegex( + # Should be: RuntimeError: Expected the init and + # the new carry produced by the operator to be a tensor of + # torch.int64 but got torch.float32 and torch.int64 + torch._dynamo.exc.UncapturedHigherOrderOpError, + ".*", + ): + gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) + + @skipIfNoDynamoSupport + @skipIfCrossRef # Arg order changes with crossref + def test_scan_simple_graph(self): + from torch._dynamo.testing import EagerAndRecordGraphs + + x = torch.randn(3, 10, 2, device=torch.device("cpu")) + init = torch.randn(1, 10, 2, device=torch.device("cpu")) + + def f(fct, init, xs): + return scan(fct, init, xs, dim=0, reverse=True) + + # Correct case + gm = make_fx(f, tracing_mode="symbolic")( + get_scan_combine_fn("add", False), init, x + ) + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, fct_1, init_1, xs_1): + slice_1 = torch.ops.aten.slice.Tensor(xs_1, 0, 0, 1) + add = torch.ops.aten.add.Tensor(init_1, slice_1); add = None + add_1 = torch.ops.aten.add.Tensor(init_1, slice_1); slice_1 = add_1 = None + sym_size_int = torch.ops.aten.sym_size.int(init_1, 1) + sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 2) + new_empty = torch.ops.aten.new_empty.default(init_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); new_empty = None + new_empty_1 = torch.ops.aten.new_empty.default(xs_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); sym_size_int = sym_size_int_1 = new_empty_1 = None + scan_combine_graph_0 = self.scan_combine_graph_0 + scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True); scan_combine_graph_0 = init_1 = xs_1 = None + getitem = scan[0] + getitem_1 = getitem[0]; getitem = None + getitem_2 = scan[1]; scan = None + getitem_3 = getitem_2[0]; getitem_2 = None + return (getitem_1, getitem_3)""", # noqa: B950 + ) + + # Check graph + backend = EagerAndRecordGraphs() + torch.compile(f, backend=backend)(get_scan_combine_fn("add", False), init, x) + gm = backend.graphs[0] + + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor): + l_init_ = L_init_ + l_xs_ = L_xs_ + slice_1 = torch.ops.aten.slice(l_xs_, 0, 0, 1, 1) + out_l = l_init_ + slice_1; out_l = None + add_1 = l_init_ + slice_1; slice_1 = add_1 = None + child = l_init_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child = None + child_1 = l_xs_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child_1 = None + scan_combine_fn_0 = self.scan_combine_fn_0 + scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, True); scan_combine_fn_0 = l_init_ = l_xs_ = None + getitem = scan[0] + getitem_1 = getitem[0]; getitem = None + getitem_2 = scan[1]; scan = None + getitem_3 = getitem_2[0]; getitem_2 = None + return (getitem_1, getitem_3)""", # noqa: B950 + ) @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") @@ -1362,6 +2990,7 @@ def f(x, y): self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True))) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with crossref def test_cond_simple_with_linear_compile_check_graph(self): from torch._dynamo.testing import EagerAndRecordGraphs @@ -1624,6 +3253,7 @@ def test_while_loop_compile(self, backend, while_loop_test): self._check_compile(fn, inp, backend=backend) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + @skipIfCrossRef # Arg order changes with cross ref def test_while_loop_simple_with_linear_compile_check_graph(self): fn, inp = WHILE_LOOP_TESTS["simple_with_linear"] from torch._dynamo.testing import EagerAndRecordGraphs @@ -3710,6 +5340,76 @@ def cond_fn(x): torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),)) self.assertEqual(cnt.frame_count, 3) + def test_hop_raises_if_not_overriding_call(self): + class WrongHop(torch._ops.HigherOrderOperator): + pass + + with self.assertRaisesRegex(TypeError, "WrongHop"): + wrong_hop = WrongHop("wrong_hop") + + def test_scan_functionalized(self): + def f(init, xs): + return scan(get_scan_combine_fn("add", False), init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + self.assertEqual( + functional_f(example_init, example_inputs), f(example_init, example_inputs) + ) + + # https://github.com/pytorch/pytorch/issues/126988 + @xfailIfTorchDynamo + def test_scan_functionalized_elem_mutation(self): + def add1(x, y): + x.add_(4) + return x + y, x + y + + def f(init, xs): + return scan(add1, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, + "Combine_fn might be modifying the input!", + ): + functional_f(example_init, example_inputs) + + def add2(x, y): + y.add_(4) + return x + y, x + y + + def f(init, xs): + return scan(add2, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, + "Combine_fn might be modifying the input!", + ): + functional_f(example_init, example_inputs) + + # https://github.com/pytorch/pytorch/issues/126988 + @xfailIfTorchDynamo + def test_scan_functionalized_elem_alias(self): + def add(x, y): + return x, x + + def f(init, xs): + return scan(add, init, xs, dim=1) + + example_inputs = torch.ones(5, 7, 4) + example_init = torch.ones(5, 1, 4) + functional_f = torch.func.functionalize(f) + with self.assertRaisesRegex( + UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!" + ): + functional_f(example_init, example_inputs) + _hop_schema_test_schema_types = [ "bool", diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index ce2f781c08b889..a1bd52a2fbb808 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4993,6 +4993,9 @@ class MySum(HigherOrderOperator): def __init__(self): super().__init__("mysum") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + mysum = MySum() @mysum.py_impl(torch._C._functorch.TransformType.Vmap) @@ -5138,8 +5141,9 @@ def wrapper(*args, **kwargs): @markDynamoStrictTest class TestCompileTransforms(TestCase): @skipIfRocm(msg="test leaks memory on ROCm") + # torch.compile is not supported on Windows CUDA. # Triton only supports GPU with SM70 or later. - @expectedFailureIf(TEST_CUDA and not SM70OrLater) + @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater)) def test_compile_vmap_hessian(self, device): # The model and inputs are a smaller version # of code at benchmark repo: diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 03744b7a8ef8fe..93e8f23d1ea40e 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1409,18 +1409,6 @@ def test_vmapjvpall(self, device, dtype, op): xfail("nn.functional.soft_margin_loss", ""), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.embedding", ""), - xfail( - "scatter_reduce", "sum" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "mean" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amin" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amax" - ), # aten::scatter_reduce.two hit the vmap fallback xfail("nn.functional.glu"), xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule xfail("linalg.lu", ""), @@ -1429,6 +1417,7 @@ def test_vmapjvpall(self, device, dtype, op): xfail("masked.cumprod", ""), xfail("renorm"), # hit vmap fallback, which is disabled xfail("t_copy"), + xfail("transpose_copy"), xfail("unsqueeze_copy"), } ), @@ -1491,18 +1480,6 @@ def test(): xfail("nanquantile"), xfail("ormqr"), xfail("put"), - xfail( - "scatter_reduce", "sum" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "mean" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amin" - ), # aten::scatter_reduce.two hit the vmap fallback - xfail( - "scatter_reduce", "amax" - ), # aten::scatter_reduce.two hit the vmap fallback xfail("quantile"), xfail("renorm"), xfail("take"), @@ -1529,7 +1506,6 @@ def test(): xfail("nn.functional.multi_margin_loss", ""), xfail("nn.functional.multilabel_margin_loss", ""), xfail("nn.functional.pdist", ""), - xfail("scatter_reduce", "prod"), xfail("nn.functional.max_unpool1d", ""), xfail("nn.functional.max_unpool3d", ""), xfail("nn.functional.max_unpool3d", "grad"), @@ -1567,6 +1543,7 @@ def test(): "index_fill" ), # aten::_unique hit the vmap fallback which is currently disabled xfail("t_copy"), + xfail("transpose_copy"), xfail("unsqueeze_copy"), } ), @@ -2427,7 +2404,7 @@ def fn(input, weight, bias): tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}), tol1( "nn.functional.conv2d", - {torch.float32: tol(atol=3e-05, rtol=5e-06)}, + {torch.float32: tol(atol=5e-05, rtol=5e-05)}, device_type="cuda", ), tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}), diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index a8358e62796278..1bfc31fe521bb5 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3923,6 +3923,44 @@ def T(*args): in_dims=(2, 1, None), ) + @parametrize("backend", PLATFORM_SPECIFIC_SDPA) + @parametrize("randomness", ["error", "same", "different"]) + def test_randomness(self, device, randomness, backend): + if device == "cpu": + raise unittest.SkipTest("This test is only for CUDA for now") + backend_ctx = sdpa_kernel([backend]) + with backend_ctx: + B = 4 + query = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + key = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + value = torch.rand(B, 4, 32, 8, 128, dtype=torch.float16, device=device) + + def f(q, k, v, dropout): + return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout) + + # No matter the randomness mode, dropout=0.0 should pass + vmap( + functools.partial(f, dropout=0.0), + in_dims=(0, 0, 0), + randomness=randomness, + )(query, key, value) + + fail_with_randomness = randomness == "error" + if backend != SDPBackend.MATH: + fail_with_randomness |= randomness == "same" + context = ( + self.assertRaises(RuntimeError) + # We currently don't support randomness == "same", and "error" should always error with randomness + if fail_with_randomness + else contextlib.nullcontext() + ) + with context: + vmap( + functools.partial(f, dropout=0.5), + in_dims=(0, 0, 0), + randomness=randomness, + )(query, key, value) + @allowVmapFallbackUsage def test_inplace_view(self, device): leaf = torch.randn(4, 5, requires_grad=True) @@ -4389,10 +4427,6 @@ def test_vmap_exhaustive(self, device, dtype, op): # TODO: implement batching rule xfail("_batch_norm_with_update"), xfail("histogram"), - xfail("scatter_reduce", "sum"), - xfail("scatter_reduce", "mean"), - xfail("scatter_reduce", "amax"), - xfail("scatter_reduce", "amin"), # `index_put` OpInfo in pytorch/pytorch has # masked index as input which is not supported xfail("index_put", ""), @@ -4408,6 +4442,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("resize_as_"), xfail("take"), xfail("tensor_split"), + xfail("transpose_copy"), xfail("to_sparse"), # TypeError: expected Tensor as element 0 in argument 0, but got float xfail("item"), @@ -4463,20 +4498,15 @@ def test_vmap_exhaustive(self, device, dtype, op): ), # Batching rule not implemented for aten::narrow.Tensor xfail("nn.functional.triplet_margin_loss", ""), xfail("nn.functional.pdist", ""), - xfail("scatter_reduce", "sum"), - xfail("scatter_reduce", "amax"), xfail("nn.functional.max_unpool1d", "grad"), xfail("nn.functional.multi_margin_loss", ""), - xfail("scatter_reduce", "prod"), xfail("nn.functional.multilabel_margin_loss", ""), - xfail("scatter_reduce", "amin"), xfail("nn.functional.max_unpool3d", "grad"), xfail("nn.functional.max_unpool2d", ""), xfail("nn.functional.max_unpool2d", "grad"), xfail("nn.functional.margin_ranking_loss", ""), xfail("nn.functional.max_unpool1d", ""), xfail("nn.functional.soft_margin_loss", ""), - xfail("scatter_reduce", "mean"), xfail("nn.functional.max_unpool3d", ""), xfail("linalg.ldl_solve", "", device_type="cpu"), xfail("chalf", ""), diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index bcf5bbc42c8329..1bff959e3c4f82 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -63,7 +63,6 @@ "aten::diagflat", "aten::divide.out_mode", "aten::divide_.Scalar", - "aten::dropout", "aten::dropout_", "aten::embedding_bag", "aten::embedding_bag.padding_idx", diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py index 83db868e4d5a97..f1cb6105b94ff2 100644 --- a/test/fx/test_matcher_utils.py +++ b/test/fx/test_matcher_utils.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.export import export_for_training from torch.fx import symbolic_trace from torch.fx.experimental.proxy_tensor import make_fx @@ -167,13 +168,13 @@ def pattern(x, weight): relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} - from torch._export import capture_pre_autograd_graph - example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) - pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) + pattern_gm = export_for_training( + WrapperModule(pattern), example_inputs + ).module() before_split_res = pattern_gm(*example_inputs) pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) after_split_res = pattern_gm(*example_inputs) @@ -198,17 +199,17 @@ def pattern(x, weight): relu_mul_by_two = relu * 2 return relu, relu_mul_by_two, {"conv": conv, "relu": relu} - from torch._export import capture_pre_autograd_graph - example_inputs = ( torch.randn(1, 3, 3, 3) * 10, torch.randn(3, 3, 3, 3), ) - pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs) + pattern_gm = export_for_training( + WrapperModule(pattern), example_inputs + ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = capture_pre_autograd_graph( + target_gm = export_for_training( WrapperModule(target_graph), example_inputs - ) + ).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map @@ -246,12 +247,10 @@ def forward(self, x): # nn.Parameter is not an allowed output type in dynamo return linear, {"linear": linear, "x": x} - from torch._export import capture_pre_autograd_graph - example_inputs = (torch.randn(3, 5),) - pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs) + pattern_gm = export_for_training(Pattern(), example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = capture_pre_autograd_graph(M(), example_inputs) + target_gm = export_for_training(M(), example_inputs).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index a4e14fbfab4488..7f23e706216a10 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -980,3 +980,38 @@ def check_replacement_nodes(self, traced, matches): return len(replacement_nodes_in_graph) self.assertEqual(check_replacement_nodes(self, traced, matches), 2) + + def test_replace_pattern_with_callback(self) -> None: + class M(torch.nn.Module): + def forward(self, x, y): + return torch.add(x, y) + + def pattern(x, y): + return torch.add(x, y) + + def replacement(x, y): + return torch.sub(torch.mul(x, y), y) + + traced = symbolic_trace(M()) + # Return the same replacement graph for all matches, but have it be a unique + # object each time. + matches = subgraph_rewriter.replace_pattern_with_filters( + traced, + pattern, + replacement_callback=lambda *args: symbolic_trace(replacement).graph, + ) + + def check_replacement_nodes(self, traced, matches): + replacement_nodes_in_graph = [ + node + for node in traced.graph.nodes + if node.target in {torch.sub, torch.mul} + ] + replacement_nodes_in_res = [r for m in matches for r in m.replacements] + self.assertEqual( + len(replacement_nodes_in_graph), len(replacement_nodes_in_res) + ) + self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res) + return len(replacement_nodes_in_graph) + + self.assertEqual(check_replacement_nodes(self, traced, matches), 2) diff --git a/test/inductor/custom_ops.cpp b/test/inductor/custom_ops.cpp index 360a2d0b862384..39c1098d95b817 100644 --- a/test/inductor/custom_ops.cpp +++ b/test/inductor/custom_ops.cpp @@ -1,4 +1,4 @@ -#include +#include // @manual=fbcode//caffe2:libtorch #include #include diff --git a/test/inductor/mock_cache.py b/test/inductor/mock_cache.py index b5effefc1fa3af..8db9cc2ba733c0 100644 --- a/test/inductor/mock_cache.py +++ b/test/inductor/mock_cache.py @@ -1,127 +1,110 @@ # Owner(s): ["module: inductor"] +from __future__ import annotations + import contextlib import dataclasses import sys import threading -import unittest.mock -from types import TracebackType -from typing import Callable, Generator, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Optional, Type, TYPE_CHECKING from typing_extensions import override, Self +from unittest.mock import patch import torch from torch._inductor import config from torch._inductor.remote_cache import RemoteCacheBackend -# The cache state is thread-local so if we're running multiple tests at once -# they won't cross contaminate. However - it needs to be "global" because we -# allow code to create new cache clients which refer to the same cache (because -# it's a remote cache). -class _MockCacheState(threading.local): - def __init__(self, name: str): - self.reset() - self._name = name - self._cache = {} - self._clients = {} # Used for Manifold +if TYPE_CHECKING: + from types import TracebackType + - def reset(self): - self.num_init = 0 +@dataclasses.dataclass +class Stats: + num_put: int = 0 + num_get_hit: int = 0 + num_get_miss: int = 0 + + def __iadd__(self, other: Stats) -> Self: + self.num_put += other.num_put + self.num_get_hit += other.num_get_hit + self.num_get_miss += other.num_get_miss + return self + + def reset(self) -> None: self.num_put = 0 self.num_get_hit = 0 self.num_get_miss = 0 - def report(self): - print( - "".join( - [ - f"{self._name} cache: ", - f"init: {self.num_init}, ", - f"puts: {self.num_put}, ", - f"misses: {self.num_get_miss}, ", - f"hits: {self.num_get_hit}, ", - ] - ), - file=sys.stderr, + def __str__(self) -> str: + return "".join( + ( + f"puts: {self.num_put}, ", + f"misses: {self.num_get_miss}, ", + f"hits: {self.num_get_hit}, ", + ) ) -class _MockLocalAutotuneCacheBackend(RemoteCacheBackend): - _state = _MockCacheState("Local") +# The cache states are thread-local so if we're running multiple tests at once +# they won't cross contaminate. However - it needs to be "global" because we +# allow code to create new cache clients which refer to the same cache (because +# it's a remote cache). - def __init__(self): - state = self._state - state.num_init += 1 - @override - def get(self, key: str) -> Optional[bytes]: - assert isinstance(key, str) +class _GlobalStats(Stats, threading.local): + def __init__(self) -> None: + self.autotune = Stats() + self.fx_graph = Stats() + self.triton = Stats() - state = self._state - if key in state._cache: - state.num_get_hit += 1 - return state._cache[key] - else: - state.num_get_miss += 1 + def reset(self) -> None: + self.autotune.reset() + self.fx_graph.reset() + self.triton.reset() - @override - def put(self, key: str, data: bytes) -> None: - assert isinstance(key, str) - assert isinstance(data, bytes) + def update(self, name: str, delta: Stats) -> None: + stat = getattr(self, name) + stat += delta - state = self._state - state.num_put += 1 - state._cache[key] = data + def report(self): + print("Cache Stats:", file=sys.stderr) + print(f" autotune: {self.autotune}", file=sys.stderr) + print(f" fx_graph: {self.fx_graph}", file=sys.stderr) + print(f" triton: {self.triton}", file=sys.stderr) -class _MockRedisRemoteCache: - _state = _MockCacheState("Redis") +global_stats = _GlobalStats() - def __init__(self, *args, **kwargs): - state = self._state - state.num_init += 1 - def get(self, key: Union[bytes, str]) -> Optional[Union[bytes, str, int, float]]: - assert isinstance(key, (bytes, str)) +class MockBackend(RemoteCacheBackend[Any]): + def __init__(self, name: str, cache: Dict[str, object]) -> None: + self._cache = cache + self._name = name - state = self._state + @staticmethod + def with_name(name: str) -> Callable[[], MockBackend]: + cache = {} - if key in state._cache: - state.num_get_hit += 1 - else: - state.num_get_miss += 1 - return state._cache.get(key) + def wrapper() -> MockBackend: + return MockBackend(name, cache) - def set(self, key: Union[bytes, str], data: Union[bytes, str, int, float]) -> None: - assert isinstance(key, (bytes, str)) - assert isinstance(data, (bytes, str, int, float)), type(data) + return wrapper - state = self._state + @override + def get(self, key: str) -> Optional[Any]: + if key in self._cache: + global_stats.update(self._name, Stats(num_get_hit=1)) + return self._cache.get(key) + else: + global_stats.update(self._name, Stats(num_get_miss=1)) + return None - # According to https://redis-py.readthedocs.io/en/stable/commands.html#redis.commands.core.CoreCommands.set - # redis accepts Union[bytes, memoryview, str, int, float] - state.num_put += 1 - state._cache[key] = data + @override + def put(self, key: str, data: Any) -> None: + global_stats.update(self._name, Stats(num_put=1)) + self._cache[key] = data -@dataclasses.dataclass -class CacheDecl: - qname: str - cls: Type[object] - f: Optional[Callable[..., object]] = None - - def patch(self) -> contextlib.AbstractContextManager: - return unittest.mock.patch(self.qname, self.f or self.cls) - - -_CACHES = ( - CacheDecl( - "torch._inductor.runtime.triton_heuristics.LocalAutotuneCache", - _MockLocalAutotuneCacheBackend, - ), - # This causes any mocking test to require 'redis'. - CacheDecl("redis.Redis", _MockRedisRemoteCache), -) - # List of configs for each cache _CACHE_CONFIG_EN = ( "fx_graph_cache", @@ -133,52 +116,6 @@ def patch(self) -> contextlib.AbstractContextManager: class PatchCaches(contextlib.AbstractContextManager): - num_init = 0 - num_put = 0 - num_get_miss = 0 - num_get_hit = 0 - _savedCacheState = {} - - @staticmethod - def get_caches() -> Tuple[CacheDecl, ...]: - if config.is_fbcode(): - from .fb.mock_cache import FB_CACHES - - return _CACHES + FB_CACHES - else: - return _CACHES - - def __init__(self): - self._contexts = [] - for decl in self.get_caches(): - self._contexts.append(decl.patch()) - - @classmethod - def reset(cls): - """ - Reset the patched cache states as well as the PatchCaches - aggregation. - """ - cls.num_init = 0 - cls.num_put = 0 - cls.num_get_miss = 0 - cls.num_get_hit = 0 - - for decl in cls.get_caches(): - decl.cls._state.reset() - - @classmethod - def update(cls): - """ - Update PatchCaches' state with the values from all the patched caches. - """ - cls.num_init = sum(decl.cls._state.num_init for decl in cls.get_caches()) - cls.num_put = sum(decl.cls._state.num_put for decl in cls.get_caches()) - cls.num_get_miss = sum( - decl.cls._state.num_get_miss for decl in cls.get_caches() - ) - cls.num_get_hit = sum(decl.cls._state.num_get_hit for decl in cls.get_caches()) - @classmethod def setUp(cls): # If this test is using PatchCaches then disable all the caches by @@ -190,50 +127,52 @@ def setUp(cls): cls._savedCacheState[name] = getattr(config, name) setattr(config, name, False) - for decl in cls.get_caches(): - if hasattr(decl.cls, "setUp"): - decl.cls.setUp() - @classmethod def tearDown(cls): - for decl in cls.get_caches()[::-1]: - if hasattr(decl.cls, "tearDown"): - decl.cls.tearDown() - # Restore cache defaults for name in _CACHE_CONFIG_EN: delattr(config, name) if name in cls._savedCacheState: setattr(config, name, cls._savedCacheState[name]) - @classmethod - def report(cls): - """ - Report cache state for all patched caches. - """ - for decl in cls.get_caches(): - decl.cls._state.report() - print( - "".join( - [ - "All caches: ", - f"init: {cls.num_init}, ", - f"puts: {cls.num_put}, ", - f"misses: {cls.num_get_miss}, ", - f"hits: {cls.num_get_hit}", - ] - ), - file=sys.stderr, - ) + def __init__(self) -> None: + self._stack = contextlib.ExitStack() def __enter__(self) -> Self: - """ - Start mocking the patched caches. - """ - self.reset() + global_stats.reset() + self._stack.__enter__() + + ctx = patch( + "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + if config.is_fbcode(): + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls", + MockBackend.with_name("triton"), + ) + self._stack.enter_context(ctx) - for ctx in self._contexts: - ctx.__enter__() return self def __exit__( @@ -242,13 +181,7 @@ def __exit__( exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - """ - Stop mocking the patched caches. - """ - for ctx in self._contexts[::-1]: - ctx.__exit__(exc_type, exc_value, traceback) - - self.update() + self._stack.__exit__(exc_type, exc_value, traceback) @contextlib.contextmanager diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 03c80c074cb5f5..fe49c04e8469f7 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -45,12 +45,13 @@ if HAS_CUDA: - import triton + import triton # @manual from torch.testing._internal.triton_utils import ( add_kernel, add_kernel_2d_autotuned, add_kernel_autotuned, + add_kernel_autotuned_weird_param_order, add_kernel_with_optional_param, add_kernel_with_scaling, mul2_inplace_kernel, @@ -75,14 +76,20 @@ ) from .test_torchinductor import copy_tests, requires_multigpu, TestFailure except ImportError: - from test_aot_inductor_utils import AOTIRunnerUtil - from test_control_flow import ( + from test_aot_inductor_utils import ( + AOTIRunnerUtil, # @manual=fbcode//caffe2/test/inductor:aot_inductor_utils-library + ) + from test_control_flow import ( # @manual=fbcode//caffe2/test/inductor:control_flow-library CondModels, prepend_counters, prepend_predicates, WhileLoopModels, ) - from test_torchinductor import copy_tests, requires_multigpu, TestFailure + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + copy_tests, + requires_multigpu, + TestFailure, + ) except (unittest.SkipTest, ImportError) as e: if __name__ == "__main__": sys.exit(0) @@ -1335,15 +1342,19 @@ def forward(self, x, b): return x + b example_inputs = ( - x := torch.randn((3, 2), device=self.device), + torch.randn((3, 2), device=self.device), torch.randn((1, 2), device=self.device), ) - torch._dynamo.mark_dynamic(x, index=0) # Create dynamic symbol + dynamic_shapes = { + "x": {0: Dim("dx"), 1: Dim.STATIC}, + "b": None, + } # Compile & run model where dynamic dim size > 0. so_path: str = AOTIRunnerUtil.compile( Repro(), example_inputs, + dynamic_shapes=dynamic_shapes, ) aot_inductor_module = AOTIRunnerUtil.load("cuda", so_path) aot_inductor_module(*example_inputs) @@ -1744,7 +1755,7 @@ def forward(self, x, y): torch._export.aot_compile(Model(), example_inputs) supported_dtype_of_cpp_wrapper_mock.assert_called_once_with( - torch.float32, self.device == "cuda" + torch.float32, self.device ) def test_consecutive_compiles(self): @@ -2236,6 +2247,27 @@ def forward(self, x, y): dynamic_shapes=dynamic_shapes, ) + def test_triton_kernel_weird_param_order(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + out = torch.empty_like(x) + add_kernel_autotuned_weird_param_order[16,]( + in_ptr0=x, + in_ptr1=x, + n_elements=x.numel(), + out_ptr=out, + ) + return out + + x = torch.randn(16, 16, device=self.device) + self.check_model(Model(), (x,)) + def test_shifted_constraint_ranges(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -2936,29 +2968,24 @@ class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x0, x1, x2, x3): - t = ( - x0.to(torch.float) - + x1.to(torch.float) - + x2.to(torch.float) - + x3.to(torch.float) - ) + def forward(self, x0, x1): + t = x0.to(torch.float) + x1.to(torch.float) return t inputs = [] for dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, - torch.float8_e4m3fnuz, - torch.float8_e5m2fnuz, + # FP8 funz are for AMD + # see https://github.com/pytorch/pytorch/issues/126734 + # torch.float8_e4m3fnuz, + # torch.float8_e5m2fnuz, ): inputs.append(torch.ones(8, 8, 8, dtype=dtype, device=self.device)) dim0 = Dim("s0", min=2, max=1024) dynamic_shapes = { "x0": {0: dim0}, "x1": {0: dim0}, - "x2": {0: dim0}, - "x3": {0: dim0}, } with torch.no_grad(), config.patch( { @@ -3243,7 +3270,9 @@ def forward(self, values, offsets): model, example_inputs_list, dynamic_shapes=dynamic_shapes ) - @common_utils.parametrize("max_autotune", [False, True]) + # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106 + # @common_utils.parametrize("max_autotune", [False, True]) + @common_utils.parametrize("max_autotune", [False]) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -3272,6 +3301,40 @@ def forward(self, x, y): Model(), example_inputs, options=dict(max_autotune=max_autotune) ) + @skip_if_no_torchvision + def test_torchvision_transforms_functional_tensor_resize(self): + import torchvision + + # https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/ + class A(torch.nn.Module): + def forward(self, image: torch.Tensor, target_size: torch.Tensor): + target_h, target_w = target_size.tolist() + torch._check(target_h > 0) + torch._check(target_w > 0) + torch._check(target_h <= 4000) + torch._check(target_w <= 4000) + + return torchvision.transforms._functional_tensor.resize( + image, + size=[target_h, target_w], + interpolation="bilinear", + antialias=False, + ) + + model = A() + example_inputs = ( + torch.ones([3, 800, 600], device=self.device), + torch.tensor([448, 336], device=self.device), + ) + dynamic_shapes = { + "image": { + 1: torch.export.Dim("height", min=1, max=4000), + 2: torch.export.Dim("width", min=1, max=4000), + }, + "target_size": None, + } + self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes) + def test_aoti_debug_printer_codegen(self): # basic addmm model to test codegen for aoti intermediate debug printer class Model(torch.nn.Module): @@ -3302,8 +3365,8 @@ def forward(self, a): ] ) - # test the default debug printing codegen - with config.patch({"aot_inductor.debug_intermediate_value_printer": 1}): + # test default debug printing all tensor values codegen + with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): result, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) @@ -3323,11 +3386,11 @@ def forward(self, a): count, ).run(code) - # test the filtered kernel names printing codegen + # test printing selected kernel's tensor values codegen filtered_kernel_name = f"aoti_torch_{self.device}_addmm_out" with config.patch( { - "aot_inductor.debug_intermediate_value_printer": 1, + "aot_inductor.debug_intermediate_value_printer": "2", "aot_inductor.filtered_kernel_names": filtered_kernel_name, } ): @@ -3356,6 +3419,59 @@ def forward(self, a): FileCheck().check_not(f"before_launch - {kernel_name}").run(code) FileCheck().check_not(f"after_launch - {kernel_name}").run(code) + def test_aoti_debug_printer_user_defined_triton_kernel(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + out = torch.zeros_like(x) + add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16) + return out + + example_inputs = ( + torch.randn(4, 4, device=self.device), + torch.randn(4, 4, device=self.device), + ) + + kernel_calls = [ + ("add_kernel_0", 3), + ] + + with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): + result, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, Model(), example_inputs + ) + # check the c shim print_tensor_handle call is triggered by the config and injected the cpp output code as expected + self.assertEqual("aoti_torch_print_tensor_handle" in code, True) + # check the codegen for debug printing around the actual kernel call is expected + for kernel_call, count in kernel_calls: + FileCheck().check_count( + f"before_launch - {kernel_call}", + count, + ).run(code) + FileCheck().check_count( + f"after_launch - {kernel_call}", + count, + ).run(code) + + def test_size_from_multi_output(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + _x, _i = torch.unique(x, sorted=True, return_inverse=True) + _x = _x.clone().detach() + return self.relu(_x), _i + + example_inputs = (torch.randn(8, device=self.device),) + self.check_model(Model(), example_inputs) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -3544,6 +3660,8 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface( is_skip=True ), + "test_size_from_multi_output": fail_stack_allocation(is_skip=True), + "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(), } # test_failures, xfail by default, set is_skip=True to skip @@ -3562,8 +3680,6 @@ def fail_non_abi_compatible_cuda(is_skip=False): "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True), # fp8 to be re-enabled for AOTI "test_fp8": fail_cuda(is_skip=True), - # non-abi compatible mode debug printer is not supported yet - "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), "test_custom_op_all_inputs": fail_non_abi_compatible_cuda(is_skip=True), "test_custom_op_missing_arg_with_default_value": fail_non_abi_compatible_cuda( is_skip=True @@ -3573,6 +3689,11 @@ def fail_non_abi_compatible_cuda(is_skip=False): is_skip=True ), "test_custom_op_with_multiple_outputs": fail_non_abi_compatible_cuda(is_skip=True), + # non-abi compatible mode aoti debug printer is not supported yet + "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_non_abi_compatible_cuda( + is_skip=True + ), } @@ -3629,6 +3750,9 @@ def fail_non_abi_compatible_cuda(is_skip=False): CUDA_TEST_FAILURES.update( { "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_cuda( + is_skip=True + ), } ) diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 0e20045cdbfc92..490b0e0324733c 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -1,78 +1,95 @@ # Owner(s): ["module: inductor"] import copy import sys +import tempfile import unittest +from parameterized import parameterized_class + import torch -from torch._inductor import config -from torch._inductor.package import load_package +from torch._inductor.package import AOTICompiledModel, load_package, package_aoti from torch._inductor.test_case import TestCase -from torch.testing._internal import common_utils +from torch.export import Dim from torch.testing._internal.common_utils import IS_FBCODE from torch.testing._internal.triton_utils import HAS_CUDA -try: - try: - from .test_torchinductor import copy_tests - except ImportError: - from test_torchinductor import copy_tests -except (unittest.SkipTest, ImportError) as e: - if __name__ == "__main__": - sys.exit(0) - raise - - -def compile(model, example_inputs, dynamic_shapes, options, device): +def compile( + model, + args, + kwargs=None, + *, + dynamic_shapes=None, + package_path=None, + inductor_configs=None, +) -> AOTICompiledModel: ep = torch.export.export( model, - example_inputs, + args, + kwargs, dynamic_shapes=dynamic_shapes, strict=False, ) - gm = ep.module() - package_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] - compiled_model = load_package(package_path, device) - return compiled_model - - -def check_model( - self: TestCase, - model, - example_inputs, - options=None, - dynamic_shapes=None, - disable_constraint_solver=False, - atol=None, - rtol=None, -): - with torch.no_grad(), config.patch( - { - "aot_inductor.package": True, - # TODO: "aot_inductor.force_mmap_weights": True, - } - ): - torch.manual_seed(0) - model = model.to(self.device) - ref_model = copy.deepcopy(model) - ref_inputs = copy.deepcopy(example_inputs) - expected = ref_model(*ref_inputs) - - torch.manual_seed(0) - compiled_model = compile( - model, - example_inputs, - dynamic_shapes, - options, - self.device, - ) - - actual = compiled_model(*example_inputs) + package_path = torch._inductor.aoti_compile_and_package( + ep, args, kwargs, package_path=package_path, inductor_configs=inductor_configs + ) # type: ignore[arg-type] + loaded = load_package(package_path) + return loaded - self.assertEqual(actual, expected, atol=atol, rtol=rtol) +@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(IS_FBCODE, "This is for OSS only") +@parameterized_class( + [ + {"device": "cpu", "package_cpp_only": False}, + {"device": "cpu", "package_cpp_only": True}, + ] + + ( + [ + {"device": "cuda", "package_cpp_only": False}, + {"device": "cuda", "package_cpp_only": True}, + ] + if sys.platform != "darwin" + else [] + ), + class_name_func=lambda cls, _, params: f"{cls.__name__}{'Cpp' if params['package_cpp_only'] else ''}_{params['device']}", +) +class TestAOTInductorPackage(TestCase): + def check_model( + self: TestCase, + model, + example_inputs, + inductor_configs=None, + dynamic_shapes=None, + disable_constraint_solver=False, + atol=None, + rtol=None, + ) -> AOTICompiledModel: + with torch.no_grad(): + torch.manual_seed(0) + model = model.to(self.device) + ref_model = copy.deepcopy(model) + ref_inputs = copy.deepcopy(example_inputs) + expected = ref_model(*ref_inputs) + + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package_cpp_only"] = self.package_cpp_only + + torch.manual_seed(0) + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + compiled_model = compile( + model, + example_inputs, + dynamic_shapes=dynamic_shapes, + inductor_configs=inductor_configs, + package_path=f.name, + ) + + actual = compiled_model(*example_inputs) + + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + return compiled_model -class AOTInductorTestsTemplate: def test_add(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -99,34 +116,84 @@ def forward(self, x, y): ) self.check_model(Model(), example_inputs) + def test_metadata(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x, y): + return x + self.linear(y) -common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + metadata = {"dummy": "moo"} + compiled_model = self.check_model( + Model(), + example_inputs, + inductor_configs={"aot_inductor.metadata": metadata}, + ) + loaded_metadata = compiled_model.get_metadata() # type: ignore[attr-defined] -@unittest.skipIf(sys.platform == "darwin" or IS_FBCODE, "No CUDA on MacOS") -class AOTInductorTestPackagedABICompatibleCuda(TestCase): - device = "cuda" - check_model = check_model + self.assertEqual(loaded_metadata.get("dummy"), "moo") + def test_multiple_methods(self): + options = { + "aot_inductor.package": True, + "aot_inductor.package_cpp_only": self.package_cpp_only, + } -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCuda, - "packaged_abi_compatible_cuda", -) + class Model1(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + def forward(self, a, b): + return torch.cat([a, b], dim=0) -@unittest.skipIf(IS_FBCODE, "This is for OSS only") -class AOTInductorTestPackagedABICompatibleCpu(TestCase): - device = "cpu" - check_model = check_model + b = torch.randn(3, 4, device=self.device) + dim0_a = Dim("dim0_a", min=1, max=10) + dim0_b = Dim("dim0_b", min=1, max=20) + dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_b}} + example_inputs1 = ( + torch.randn(2, 4, device=self.device), + torch.randn(3, 4, device=self.device), + ) + ep1 = torch.export.export( + Model1(), example_inputs1, dynamic_shapes=dynamic_shapes + ) + aoti_files1 = torch._inductor.aot_compile( + ep1.module(), example_inputs1, options=options + ) + + class Model2(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + def forward(self, x): + t = torch.tensor(x.size(-1), device=self.device, dtype=torch.float) + t = torch.sqrt(t * 3) + return x * t + + example_inputs2 = (torch.randn(5, 5, device=self.device),) + ep2 = torch.export.export(Model2(self.device), example_inputs2) + aoti_files2 = torch._inductor.aot_compile( + ep2.module(), example_inputs2, options=options + ) + + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + package_path = package_aoti( + f.name, {"model1": aoti_files1, "model2": aoti_files2} + ) + loaded1 = load_package(package_path, "model1") + loaded2 = load_package(package_path, "model2") + + self.assertEqual(loaded1(*example_inputs1), ep1.module()(*example_inputs1)) + self.assertEqual(loaded2(*example_inputs2), ep2.module()(*example_inputs2)) -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestPackagedABICompatibleCpu, - "packaged_abi_compatible_cpu", -) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 27433876c9ca45..425e87bcfbb07d 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -67,7 +67,7 @@ def compile( @staticmethod def load_runner(device, so_path): if IS_FBCODE: - from .fb import test_aot_inductor_model_runner_pybind + from .fb import test_aot_inductor_model_runner_pybind # @manual with tempfile.TemporaryDirectory() as temp_dir: # copy *.so file to a unique path just before loading diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py new file mode 100644 index 00000000000000..019e88cf1a188e --- /dev/null +++ b/test/inductor/test_auto_functionalize.py @@ -0,0 +1,972 @@ +# Owner(s): ["module: functionalization"] + +import numpy as np + +import torch +import torch._dynamo.testing +import torch._inductor.config as inductor_config +import torch._inductor.test_case +import torch.onnx.operators +import torch.utils._pytree as pytree +import torch.utils.cpp_extension +from torch import Tensor +from torch.testing._internal.logging_utils import logs_to_string + + +class AutoFunctionalizeTests(torch._inductor.test_case.TestCase): + def test_auto_functionalize_can_with_default(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo_impl(a, b, c=None, d=None, e=-1): + a + b + return + + def f(a, mode): + return torch.ops.mylib.foo( + a, + 0, + ) + + a = torch.tensor([10, 10, 10], dtype=torch.int64) + + torch.compile(f)(a, 0) + + def test_auto_functionalize_can_with_none_return(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo(Tensor x, Tensor(a!) out) -> None") + + def foo_impl(x, out): + out.copy_(x) + + lib.impl("foo", foo_impl, "CompositeExplicitAutograd") + x = torch.randn(3) + out = torch.zeros(3) + + @torch.compile + def f(x, out): + torch.ops.mylib.foo(x, out) + + f(x, out) + + def test_auto_functionalize_self_as_mutate_arg(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo(Tensor(a!) self) -> None") + + def foo_impl(self: torch.Tensor) -> None: + self.sin_() + + x = torch.randn(3) + lib.impl("foo", foo_impl, "CompositeExplicitAutograd") + + @torch.compile(backend="inductor", fullgraph=True) + def f(x): + torch.ops.mylib.foo(x) + + f(x) + + def test_auto_functionalize_tensorlist(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out): + for o in out: + o.copy_(all_gather_output) + + def f(all_gather_output, all_gather_input_split_sizes, dim, out): + torch.ops.mylib.foo( + all_gather_output, all_gather_input_split_sizes, dim, out + ) + + a = torch.ones(4) + b = [2, 3] + c = 0 + d = [torch.empty(4) for _ in range(2)] + orig_args = (a, b, c, d) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + def test_can_auto_functionalize(self): + from torch._higher_order_ops.auto_functionalize import can_auto_functionalize + + expected_true = [ + "(Tensor(a!) x) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)", + ] + expected_false = [ + "(Tensor x) -> ()", + "(Tensor(a) x) -> Tensor(a)", + "(Tensor(a!) x) -> Tensor(a!)", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))", + "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])", + ] + for schema in expected_true: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define("mylib::a", schema, lib=lib) + + self.assertTrue( + can_auto_functionalize(torch.ops.mylib.a.default), msg=schema + ) + self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) + + for schema in expected_false: + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define("mylib::a", schema, lib=lib) + self.assertFalse( + can_auto_functionalize(torch.ops.mylib.a.default), msg=schema + ) + self.assertFalse(can_auto_functionalize(torch.ops.mylib.a)) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_old(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # Check the graph under static shapes + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \ +"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \ +arg3_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_with_returns_old(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + return y[0] + w, y[1] + n + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_abstract(x, y, z, w, n): + return y[0] + w, y[1] + n + + def f(x, y, z, n): + return torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( + *compiled_args + ) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None + return (getitem_4, getitem_5)""", # noqa: B950 + ignore_comments=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + eager_out = f(*eager_args) + self.assertEqual(compiled_args, eager_args) + self.assertEqual(compiled_out, eager_out) + + def test_auto_functionalize_on_view(self): + for value in [True, False]: + with torch.library._scoped_library( + "mylib", "FRAGMENT" + ) as lib, inductor_config.patch({"enable_auto_functionalized_v2": value}): + torch.library.define( + "mylib::foo", + "(Tensor(a!) x) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x): + x_np = x.detach().numpy() # view + np.sin(x_np, out=x_np) + return + + x = torch.randn(3) + expected = x.sin() + torch.ops.mylib.foo(x) + assert torch.allclose(x, expected) + + @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) + def f(x): + x = x.clone() + y = x[:] + torch.ops.mylib.foo(y) + return x + + y = f(x) + self.assertEqual(y, x.sin()) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_auto_functionalize_optional_old(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + if x is not None: + x.add_(y[0] + w) + if z is not None: + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = None + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): + # No stacktrace found for following nodes + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ +arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + def test_unbacked_auto_functionalize_op(self): + @torch.library.custom_op( + "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"] + ) + def mk_image(decoder: Tensor) -> Tensor: + return torch.randn(2, 3, 4, 5) + + @torch.library.register_fake("mylib::mk_image") + def _(decoder: Tensor) -> Tensor: + image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)] + return torch.empty(image_size) + + @torch.compile(fullgraph=True) + def f(x): + return torch.ops.mylib.mk_image.default(x) + + x = torch.zeros(100, dtype=torch.int64) + f(x) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_v2(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", dynamic=_dynamic, fullgraph=True)( + *compiled_args + ) + + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + def run_aot_eager(self, f, orig_args, _dynamic=False): + aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + ) + + result = None + with ctx(): + result = torch.compile( + f, backend="aot_eager", fullgraph=True, dynamic=_dynamic + )(*aot_eager_args) + + graph = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + return [aot_eager_args, result, graph] + + def run_inductor(self, f, orig_args, _dynamic=False): + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + result = None + with ctx(): + result = torch.compile( + f, backend="inductor", fullgraph=True, dynamic=_dynamic + )(*compiled_args) + + graph = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip() + + return [compiled_args, result, graph] + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_with_returns_v2(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + x.add_(y[0] + w) + z.add_(y[1] + n) + return y[0] + w, y[1] + n + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_abstract(x, y, z, w, n): + return y[0] + w, y[1] + n + + def f(x, y, z, n): + return torch.ops.mylib.foo(x, y, z, 2, n) + + x = torch.randn(3) + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_out = torch.compile(f, backend="inductor", fullgraph=True)( + *compiled_args + ) + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None + getitem_4: "f32[3][1]cpu" = foo_default[0] + getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None + + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + return (getitem_4, getitem_5)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + eager_out = f(*eager_args) + self.assertEqual(compiled_args, eager_args) + self.assertEqual(compiled_out, eager_out) + + # foo takes two inputs that are not views. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra1(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x, y): + torch.ops.mylib.foo(x, y) + return x + y + + orig_args = (torch.randn(2), torch.randn(2)) + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg2_1, arg1_1]) + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[s0][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_1); arg2_1 = getitem_1 = copy__1 = None + return (add,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [arg1_1, arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_2); arg0_1 = getitem_2 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy__1 = None + return (add,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg2_1, arg1_1); foo_default = None + add: "f32[s0][1]cpu" = torch.ops.aten.add.Tensor(arg2_1, arg1_1) + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None + return (add,)""", + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): + foo_default = torch.ops.mylib.foo.default(arg1_1, arg0_1); foo_default = None + add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(arg1_1, arg0_1) + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None + return (add,)""", + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes two views on the same input, function does not have return. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra2(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + return + + orig_args = [torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes two views on the same input, function returns both views and the input + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra3(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + return (a, b, x) + + orig_args = [torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) + [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None + select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0) + select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None + return (select_2, select_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) + foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) + select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None + return (select_2, select_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # foo takes a mutable list with views in addition to other args. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_extra4(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!)[] y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y[0].sin_() + + def f(x, y, z): + a = x[0] + b = z[0] + torch.ops.mylib.foo(a, [b, y]) + + orig_args = [torch.randn(2), torch.randn(2), torch.randn(2)] + + [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) + [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args[2], eager_args[2]) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_length = 2, _y_0_base_index = 1, _y_0_size = (), _y_0_stride = (), _y_0_storage_offset = 0, _y_1_base_index = 2, _all_bases = [arg0_1, arg1_1, arg2_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2] + getitem_3: "f32[2][1]cpu" = auto_functionalized_v2[3]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None + copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2][1]cpu"): + as_strided_default: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 0) + as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 0) + foo_default = torch.ops.mylib.foo.default(as_strided_default, [as_strided_default_1, arg2_1]); as_strided_default = as_strided_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None + copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None + copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_auto_functionalize_optional_v2(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y, z, w, n): + if x is not None: + x.add_(y[0] + w) + if z is not None: + z.add_(y[1] + n) + + def f(x, y, z, n): + torch.ops.mylib.foo(x, y, z, 2, n) + + x = None + y = (torch.randn(3), torch.randn(3)) + z = torch.randn(3) + n = torch.randn(3) + orig_args = (x, y, z, n) + + compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args) + + if torch._dynamo.config.assume_static_by_default: + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + self.assertExpectedInline( + post_grad_graphs, + """\ +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): + foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None + return ()""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + f(*eager_args) + self.assertEqual(compiled_args, eager_args) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=False) + def test_inference_mode1_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra1() + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode2_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra2() + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode3_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra3() + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_inference_mode4_v2(self): + with torch.inference_mode(): + self.test_auto_functionalize_extra4() + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic_v2(self): + self.test_auto_functionalize_v2(_dynamic=True) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic2_v2(self): + self.test_auto_functionalize_extra1(_dynamic=True) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_dynamic3_v2(self): + self.test_auto_functionalize_extra2(_dynamic=True) + + # foo takes two views on the same input, function does not have return. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_graph_input_is_view(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x): + pass + + @torch.compile(fullgraph=True, dynamic=False, backend="aot_eager") + def f(x): + a = x[0] + torch.ops.mylib.foo(a) + return + + x = torch.tensor([[1, 2], [3, 4]]) + # This would fail if auto_functionalized_v2 uses clone and not clone_preserve_strides + # to clone not-inplaced args. + f(x[1]) + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests() diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 9011e11613cd01..9eb25aa305a1a5 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -4,15 +4,12 @@ import sys import torch +from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.test_operators import realize from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - slowTest, - TEST_WITH_ASAN, -) +from torch.testing._internal.common_utils import slowTest, TEST_WITH_ASAN from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -23,22 +20,15 @@ import contextlib import unittest +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + copy_tests, +) from torch._inductor import config from torch._inductor.scheduler import Scheduler -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - - -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests - - class TestCase(InductorTestCase): @classmethod def setUpClass(cls): @@ -182,6 +172,14 @@ def foo(m, inp): "empty_strided_cuda", 2, exactly=True ).check("return").run(c) + def test_tield_kernel_fusion(self): + def f(x): + y = realize(x + x.t()) + return y + 1 + + x = torch.randn(1024, 1024, device=self.device) + self.common(f, (x,)) + if HAS_CUDA and not TEST_WITH_ASAN: @@ -191,6 +189,37 @@ class BenchmarkFusionCudaTest(TestCase): copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda") + class BenchmarkingTest(TestCase): + @unittest.skipIf( + torch.cuda.device_count() < 2, "The test need at least 2 devices" + ) + def test_benchmark_on_non_zero_device(self): + hit_count = 0 + with torch.cuda.device("cuda:0"): + + @torch.compile + def relu(x): + return realize(x.relu()) + x + + x = torch.randn(int(16e6), device="cuda:1") + + orig_benchmark_fused_nodes = TritonScheduling.benchmark_fused_nodes + + def mock_benchmark_fused_nodes(*args, **kwargs): + nonlocal hit_count + hit_count += 1 + ms, path = orig_benchmark_fused_nodes(*args, **kwargs) + self.assertTrue(ms > 0) + return ms, path + + with unittest.mock.patch.object( + TritonScheduling, + "benchmark_fused_nodes", + mock_benchmark_fused_nodes, + ): + relu(x) + self.assertTrue(hit_count > 0) + class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase): @classmethod def setUpClass(cls): diff --git a/test/inductor/test_binary_folding.py b/test/inductor/test_binary_folding.py index 045b330d02456c..20f613fc746f37 100644 --- a/test/inductor/test_binary_folding.py +++ b/test/inductor/test_binary_folding.py @@ -4,7 +4,6 @@ import itertools import os import sys -import unittest import torch from torch import nn @@ -16,22 +15,18 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN +from inductor.test_inductor_freezing import ( + TestCase, # @manual=fbcode//caffe2/test/inductor:inductor_freezing-library +) +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + copy_tests, +) +from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.inductor_utils import skipCUDAIf -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - -from inductor.test_inductor_freezing import TestCase -from inductor.test_torchinductor import check_model, check_model_gpu, copy_tests - - importlib.import_module("functorch") importlib.import_module("filelock") diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index 94e53ad1b3563a..dc507e42aeb315 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -43,7 +43,7 @@ def setUp(self): torch.random.manual_seed(1234) try: - import ck4inductor + import ck4inductor # @manual self.ck_dir = os.path.dirname(ck4inductor.__file__) os.environ["TORCHINDUCTOR_CK_DIR"] = self.ck_dir diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 4b7f1039611d36..516d16fba7869d 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -41,9 +41,9 @@ try: - from .mock_cache import patch_fbcode, PatchCaches + from .mock_cache import global_stats, patch_fbcode, PatchCaches except ImportError: - from mock_cache import PatchCaches # @manual + from mock_cache import global_stats, patch_fbcode, PatchCaches # @manual HAS_TRITON = has_triton() @@ -160,6 +160,7 @@ def fn(x, y): self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 1) @requires_triton() + @config.patch({"fx_graph_remote_cache": True}) @parametrize("device", (GPU_TYPE, "cpu")) @parametrize("dtype", (torch.float32, torch.bfloat16)) @parametrize("dynamic", (False, True)) @@ -179,7 +180,6 @@ def fn(x, y): with config.patch( { - "fx_graph_cache": False, "fx_graph_remote_cache": True, } ), patch.dict(os.environ), PatchCaches(): @@ -190,10 +190,10 @@ def fn(x, y): self.assertEqual(fn(a, b), compiled_fn(a, b)) reset() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.fx_graph.num_get_hit, 3) + self.assertEqual(global_stats.fx_graph.num_get_miss, 1) + self.assertEqual(global_stats.fx_graph.num_put, 1) @requires_triton() @config.patch({"fx_graph_cache": True}) @@ -793,6 +793,8 @@ def reset(self): torch._dynamo.reset() clear_inductor_caches() + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not SM80OrLater, "Requires SM80+") @config.patch({"fx_graph_cache": False}) @config.patch({"fx_graph_remote_cache": False}) @config.patch({"autotune_local_cache": False}) @@ -800,9 +802,6 @@ def reset(self): @config.patch({"max_autotune": True}) @parametrize("fbcode", (False,) + (True,) * config.is_fbcode()) def test_autotune_cache(self, fbcode: bool): - if not fbcode: - self.skipTest("Redis for autotune is currently broken") - class Model(torch.nn.Module): def forward(self, x, y, a, b): return x + y, a + b @@ -819,18 +818,17 @@ def f(x, y, a, b): with PatchCaches(), patch_fbcode(fbcode): f_compiled(x, y, a, b) - PatchCaches.update() - self.assertEqual(PatchCaches.num_get_hit, 0) - self.assertEqual(PatchCaches.num_get_miss, 2) - self.assertEqual(PatchCaches.num_put, 2) + self.assertEqual(global_stats.autotune.num_get_hit, 0) + self.assertEqual(global_stats.autotune.num_get_miss, 2) + self.assertEqual(global_stats.autotune.num_put, 2) self.reset() f_compiled(x, y, a, b) - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 2) - self.assertEqual(PatchCaches.num_get_miss, 2) - self.assertEqual(PatchCaches.num_put, 2) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 2) + self.assertEqual(global_stats.autotune.num_get_miss, 2) + self.assertEqual(global_stats.autotune.num_put, 2) class TestUtils(TestCase): diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index d026f7a3ba0f45..bccdacab2a679a 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] +import contextlib import sys import unittest @@ -19,7 +20,10 @@ try: from .test_torchinductor import check_model, check_model_cuda except ImportError: - from test_torchinductor import check_model, check_model_cuda + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": @@ -36,12 +40,20 @@ class ComboKernelTests(TestCase): def setUp(self): super().setUp() torch._inductor.metrics.reset() - torch._inductor.config.combo_kernels = True - torch._inductor.config.benchmark_combo_kernel = False + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": False, + } + ) + ) def tearDown(self): - super().tearDown() + self._test_stack.close() torch._inductor.metrics.reset() + super().tearDown() @requires_cuda def test_activation_functions(self): @@ -157,12 +169,20 @@ class ComboKernelBenchmarkTests(TestCase): def setUp(self): super().setUp() torch._inductor.metrics.reset() - torch._inductor.config.combo_kernels = True - torch._inductor.config.benchmark_combo_kernel = True + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": True, + } + ) + ) def tearDown(self): - super().tearDown() + self._test_stack.close() torch._inductor.metrics.reset() + super().tearDown() @requires_cuda def test_activation_benchmark(self): @@ -227,6 +247,30 @@ def test_mutated(a, b, c, d): out_eager = test_mutated(*inps) out_compiled = torch.compile(test_mutated)(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) + + @requires_cuda + def test_round_robin_dispatch(self): + # combo kernel dispatch strategy: round robin + def test_mutated(a, b, c, d): + a.add_(1) + b.sigmoid_() + c = torch.add(c, 5) + d.tanh_() + + return a, b, c, d + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 5, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(5, 18, device="cuda"), + ] + + out_eager = test_mutated(*inps) + out_compiled = torch.compile(test_mutated)(*inps) + self.assertEqual(out_eager, out_compiled) self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) @@ -252,6 +296,245 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) + @requires_cuda + def test_persistent_reduction_no_x_dim(self): + def fn(x, y): + return x.sum(1), y.sum(1) + + inps = ( + torch.rand(16, 256, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + + +@instantiate_parametrized_tests +class ComboKernelDynamicShapesTests(TestCase): + check_model_cuda = check_model_cuda + check_model_cpu = check_model + check_kernel_count = True + + def setUp(self): + super().setUp() + torch._inductor.metrics.reset() + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": True, + } + ) + ) + self._test_stack.enter_context( + torch._dynamo.config.patch( + { + "automatic_dynamic_shapes": False, + "assume_static_by_default": False, + } + ) + ) + + def tearDown(self): + self._test_stack.close() + torch._inductor.metrics.reset() + super().tearDown() + + @requires_cuda + def test_dynamic_shapes_activations(self): + def test_activations(a, b, c): + a1 = torch.nn.functional.relu(a) + b1 = torch.nn.functional.sigmoid(b) + c1 = torch.nn.functional.tanh(c) + return a1, b1, c1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + ] + + out_eager = test_activations(*inps) + out_compiled = torch.compile(test_activations)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) + + @requires_cuda + def test_dynamic_shapes_2d_blocking(self): + def fn(a0, a1, a2, b0, b1, b2): + c0 = torch.add(a0, b0) + c1 = torch.add(a1, b1) + c2 = torch.add(a2, b2) + return c0, c1, c2 + + self.check_model_cuda( + fn, + ( + torch.rand(30, 20, device="cuda"), + torch.rand(40, 30, device="cuda"), + torch.rand(36, 40, device="cuda"), + torch.rand(30, 20, device="cuda"), + torch.rand(30, 40, device="cuda").t(), + torch.rand(40, 36, device="cuda").t(), + ), + ) + + self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) + + @requires_cuda + def test_dynamic_shapes_reduce(self): + def test_reduce(a, b, c, d): + a1 = torch.sum(a, dim=0) + b1 = torch.max(b, dim=0) + c1 = torch.min(c, dim=0) + d1 = torch.nn.functional.tanh(d) + + return a1, b1, c1, d1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(30, 8, device="cuda"), + ] + + out_eager = test_reduce(*inps) + out_compiled = torch.compile(test_reduce)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) + + @requires_cuda + def test_dynamic_shapes_mutated(self): + # combo kernel dispatch strategy: round robin + def test_mutated(a, b, c, d): + a.add_(1) + b.sigmoid_() + c = torch.add(c, 5) + d.tanh_() + + return a, b, c, d + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 5, device="cuda"), + torch.rand(10, 10, device="cuda"), + torch.rand(5, 18, device="cuda"), + ] + + out_eager = test_mutated(*inps) + out_compiled = torch.compile(test_mutated)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6) + + @requires_cuda + @torch._inductor.config.patch("combo_kernels_autotune", 0) + def test_dynamic_shapes_activations_no_autotune(self): + def test_activations(a, b, c): + a1 = torch.nn.functional.relu(a) + b1 = torch.nn.functional.sigmoid(b) + c1 = torch.nn.functional.tanh(c) + return a1, b1, c1 + + inps = [ + torch.rand(10, 10, device="cuda"), + torch.rand(20, 20, device="cuda"), + torch.rand(10, 10, device="cuda"), + ] + + out_eager = test_activations(*inps) + out_compiled = torch.compile(test_activations)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + def test_dynamic_shapes_persistent_reduction_no_x_dim(self): + def fn(x, y): + return x.sum(1), y.sum(1) + + inps = ( + torch.rand(16, 256, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + def test_dynamic_shapes_2d_blocking_round_robin(self): + def fn(a0, a1, a2, b0, b1, b2): + c0 = torch.add(a0, b0) + c1 = torch.add(a1, b1) + c2 = torch.add(a2, b2) + return c0, c1, c2 + + inps = ( + torch.rand(20, 30, device="cuda"), + torch.rand(30, 30, device="cuda"), + torch.rand(40, 32, device="cuda"), + torch.rand(30, 20, device="cuda").t(), + torch.rand(30, 30, device="cuda").t(), + torch.rand(32, 40, device="cuda").t(), + ) + + out_eager = fn(*inps) + compiled = torch.compile(fn) + out_compiled = compiled(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) + torch._inductor.metrics.reset() + + inps = ( + torch.rand(24, 30, device="cuda"), + torch.rand(32, 30, device="cuda"), + torch.rand(48, 32, device="cuda"), + torch.rand(30, 24, device="cuda").t(), + torch.rand(30, 32, device="cuda").t(), + torch.rand(32, 48, device="cuda").t(), + ) + out_compiled = compiled(*inps) + out_eager = fn(*inps) + self.assertEqual(out_eager, out_compiled) + self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + @torch._inductor.config.patch("triton.autotune_at_compile_time", True) + def test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda(self): + def fn(x, y, z): + return x.sum(1), y.mean(1), z.max(1) + + inps = ( + torch.rand(16, 128, device="cuda"), + torch.rand(32, 128, device="cuda"), + torch.rand(32, 256, device="cuda"), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index f1abcb9677f29c..cbe6a268c52275 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -19,23 +19,31 @@ import torch.nn.functional as F from torch import _inductor as inductor from torch._dynamo import compiled_autograd, config +from torch._dynamo.backends.debugging import aot_eager +from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.common_utils import skipIfWindows +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string # note: these tests are not run on windows due to inductor_utils.HAS_CPU -def make_compiler_fn(fullgraph=True, dynamic=True): +def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"): + assert backend in ["inductor", "aot_eager"] + def _compiler_fn(gm): """Same as torch.compile() but counts number of compiles""" def _inner_compiler(gm_, example_inputs_): counters["compiled_autograd"]["compiles"] += 1 - return inductor.compile(gm_, example_inputs_) + if backend == "inductor": + return inductor.compile(gm_, example_inputs_) + elif backend == "aot_eager": + return aot_eager(gm_, example_inputs_) return torch.compile( gm, backend=_inner_compiler, fullgraph=fullgraph, dynamic=dynamic @@ -171,6 +179,26 @@ def fn(): self.check_output_and_recompiles(fn) + def test_graph_break_custom_op(self): + @torch.library.custom_op("mylib::sin", mutates_args={}) + def sin(x: torch.Tensor) -> torch.Tensor: + return x.sin() + + def setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return grad * x.cos() + + sin.register_autograd(backward, setup_context=setup_context) + + x = torch.randn(3, requires_grad=True) + y = sin(x.clone()).sum() + with compiled_autograd.enable(compiler_fn): + y.backward() + def test_tensor_grad_hook1(self): def fn(): for _ in range(3): @@ -783,9 +811,9 @@ def fn(): self.check_output_and_recompiles(fn, count=1) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_issue106555(self): - DEVICE = torch.device("cuda:0") + DEVICE = torch.device(GPU_TYPE, 0) NUM_FEATURES = 256 def bias_sigmoid_mul(x1, x2, bias): @@ -828,7 +856,8 @@ def _forward(self, x): x = x + self.module_with_jit_2(x.transpose(-2, -3)).transpose(-2, -3) return x - torch.cuda.set_device(device=DEVICE) + device_interface = get_interface_for_device(GPU_TYPE) + device_interface.set_device(device=DEVICE) torch.manual_seed(1234567890) model = Model() model.train() @@ -1048,7 +1077,7 @@ def backward(ctx, gO_1, gO_2, gO_3): self.check_output_and_recompiles(fn, count=2) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_logging_tensor_flaky(self) -> None: # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore # resulting in: @@ -1061,7 +1090,7 @@ def _fn(x): return x x = torch.arange( - 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + 1, 10, requires_grad=True, dtype=torch.float16, device=GPU_TYPE ) out = _fn(x) loss = out.sum() @@ -1094,7 +1123,7 @@ def forward(inputs): compiled_fn(inputs) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_custom_fn_output_metadata(self): def my_compiler_fn(gm): for node in gm.graph.nodes: @@ -1122,7 +1151,7 @@ def backward(ctx, gO): return gO x = torch.arange( - 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" + 1, 10, requires_grad=True, dtype=torch.float16, device=GPU_TYPE ) x_view = x.view(3, 3) out = MyFn.apply(x_view) @@ -1410,6 +1439,179 @@ def _compiler_fn(gm): f, compiler_fn=compiler_fn_with_op_check, compile_fn=False ) + def test_trace_auto_functionalized(self): + torch.library.define( + "testlib::foo", + "(Tensor(a!) x) -> (Tensor)", + tags=torch.Tag.pt2_compliant_tag, + ) + torch.library.define( + "testlib::foo_mutated", + "(Tensor(a!) x) -> (Tensor)", + tags=torch.Tag.pt2_compliant_tag, + ) + + @torch.library.impl("testlib::foo", "cpu") + def foo(x): + x.add_(5) + return x + + @torch.library.impl("testlib::foo", "Meta") + def foo_meta(x): + return x + + @torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd") + def foo_mutated(x): + return torch.ops.testlib.foo(x) + + def _get_custom_policy(must_recompute_list=None): + def _custom_policy(ctx, func, *args, **kwargs): + if must_recompute_list is not None and func in must_recompute_list: + return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE + else: + return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def context_fn(): + must_recompute_list = [ + torch.ops.higher_order.auto_functionalized, + ] + return torch.utils.checkpoint.create_selective_checkpoint_contexts( + _get_custom_policy( + must_recompute_list=must_recompute_list, + ), + ) + + def g(x): + x = torch.matmul(x, x) + torch.ops.testlib.foo_mutated(x) + return torch.matmul(x, x) + + def g_cp(x): + return torch.utils.checkpoint.checkpoint( + g, x, use_reentrant=False, context_fn=context_fn + ) + + def f(): + inps = (torch.randn(4, 4, requires_grad=True),) + output = torch.compile(g_cp, backend="aot_eager", fullgraph=True)(*inps) + output.sum().backward() + return output, inps[0].grad + + """ + Walkthrough of what happens with `auto_functionalized`: + 1. `auto_functionalized` op is inserted into the graph during AOTAutograd functionalization. + We force the op to be recomputed (by using SAC), so it appears in the backward graph. + 2. The AOT backward graph looks like: + ``` + ===== Backward graph 0 ===== + def forward(self, primals_1: "f32[4, 4][4, 1]cpu", tangents_1: "f32[4, 4][4, 1]cpu"): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) + ... + return (add_1,) + ``` + 3. The Compiled Autograd graph looks like: + ``` + ===== Compiled autograd graph ===== + def forward(self, inputs, sizes, scalars, hooks): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) + ... + return [] + ``` + 4. The Dynamo graph captured by Compiled Autograd looks like: + ``` + ===== __compiled_fn_3 ===== + def forward(self, L_inputs_ : list): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) + ... + return (new_grad,) + ``` + 5. The Compiled Autograd's AOT "forward-only" graph looks like: + ``` + ===== Forward graph 1 ===== + def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[4, 4][4, 1]cpu"): + ... + X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) + ... + return (clone_1,) + ``` + 6. The `auto_functionalized` op should then be lowered using the normal lowering path in Inductor. + """ + + compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager") + + def make_compiler_fn_with_op_check(): + def _compiler_fn(gm): + # Checks that `auto_functionalized` op exists in Compiled Autograd's Dynamo graph. + self.assertTrue( + any( + node.target is torch.ops.higher_order.auto_functionalized + for node in gm.graph.nodes + ), + f"`torch.ops.higher_order.auto_functionalized` op not found in {gm.graph}", + ) + return compiler_fn(gm) + + return _compiler_fn + + compiler_fn_with_op_check = make_compiler_fn_with_op_check() + self.check_output_and_recompiles( + f, compiler_fn=compiler_fn_with_op_check, compile_fn=False + ) + + def test_non_traceable_autograd_cpp_node(self): + cpp_source = """ +struct CustomOpAutogradFunction : public torch::autograd::Function { + static constexpr bool is_traceable = false; + + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& x) { + return x; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext *ctx, + torch::autograd::variable_list grad_output) { + return grad_output; + } +}; + +torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { + return CustomOpAutogradFunction::apply(x); +} + +TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) { + m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); +} + """ + + module = torch.utils.cpp_extension.load_inline( + name="test_non_traceable_autograd_cpp_node", + cpp_sources=cpp_source, + functions="custom_op_backed_by_autograd_fn", + verbose=True, + ) + + def fn(): + x = torch.ones(10, 10, requires_grad=True) + out = torch.ops.test_non_traceable_autograd_cpp_node.custom_op_backed_by_autograd_fn( + x + ) + loss = out.sum() + loss.backward() + + with self.assertRaisesRegex( + RuntimeError, + "https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/", + ), compiled_autograd.enable(compiler_fn): + fn() + + @unittest.skip("Flaky, cache from test ordering affects test. #135369") def test_autograd_cpp_node(self): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { @@ -1891,17 +2093,20 @@ def fn(): self.check_output_and_recompiles(fn, 3) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_free_activation_memory(self): script = """ import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.testing._internal.inductor_utils import GPU_TYPE def main(): - assert(torch.cuda.memory_allocated() == 0) + device_interface = get_interface_for_device(GPU_TYPE) + assert(device_interface.memory_allocated() == 0) # Use an op to check that the memory is freed by the time the op is executed def assertion_impl(to_clone): - mem_allocated = torch.cuda.memory_allocated() + mem_allocated = device_interface.memory_allocated() assert mem_allocated < 4000000 # some activations should be freed return to_clone.clone() @@ -1924,8 +2129,8 @@ def forward(activations): compiled_fn = torch.compile(gm) # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes) - activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")] - assert torch.cuda.memory_allocated() > 4000000 + activations = [torch.ones(1000000, dtype=torch.float32, device=GPU_TYPE)] + assert device_interface.memory_allocated() > 4000000 out = compiled_fn(activations) assert len(activations) == 0 @@ -1934,19 +2139,22 @@ def forward(activations): """ self.run_as_subprocess(script) - @unittest.skipIf(not HAS_CUDA, "requires cuda") + @unittest.skipIf(not HAS_GPU, "requires gpu") def test_free_activation_memory_subclass(self): # cover the case when aot inputs have subclasses, resulting in a different runtime wrapper script = """ import torch +from torch._dynamo.device_interface import get_interface_for_device +from torch.testing._internal.inductor_utils import GPU_TYPE def main(): - assert torch.cuda.memory_allocated() == 0 + device_interface = get_interface_for_device(GPU_TYPE) + assert device_interface.memory_allocated() == 0 # Use an op to check that the memory is freed by the time the op is executed def assertion_impl(to_clone): - mem_allocated = torch.cuda.memory_allocated() + mem_allocated = device_interface.memory_allocated() assert mem_allocated < 1200000 # some activations should be freed assert mem_allocated > 800000 # currently subclasses don't seem to be freed in inductor return to_clone.clone() @@ -1974,23 +2182,24 @@ def fn(inputs): activations = [ jagged_from_list( [ - torch.ones((1, 100000), device="cuda"), # 400,000 bytes - torch.ones((1, 100000), device="cuda"), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes ], None, )[ 0 ], # NestedTensor - torch.ones((1, 100000), device="cuda"), # 400,000 bytes + torch.ones((1, 100000), device=GPU_TYPE), # 400,000 bytes ] # 1,200,000 bytes (3 * 4 * 100,000 bytes) - assert torch.cuda.memory_allocated() > 1200000 + assert device_interface.memory_allocated() > 1200000 out = compiled_fn(activations) assert len(activations) == 0 main() """ + self.run_as_subprocess(script) def test_callback_graph_break_throws_error(self): called = [0] @@ -2341,6 +2550,7 @@ def fn(x, obj): sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) ) + @skipIfWindows(msg="AssertionError: Scalars are not equal!") def test_verbose_logs_cpp(self): torch._logging.set_logs(compiled_autograd_verbose=True) @@ -2468,6 +2678,20 @@ def tensor_hook(_): self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 1) + def test_reentrant_checkpointing(self): + def fn(x): + y = x.sin() + z = y.cos() + return (y * z).sum() + + inp = torch.rand(10, 10, requires_grad=True) + out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True) + with self.assertRaisesRegex( + RuntimeError, + r"\(e.g. reentrant checkpointing\), this is not supported yet\.", + ), torch._dynamo.compiled_autograd.enable(torch.compile): + out.backward() + def load_test_module(name): testdir = Path(__file__).absolute().parent.parent @@ -2527,7 +2751,7 @@ def wrap_test_class(orig_cls): "test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook "test_hooks", # uses assert in hook "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose - "test_save_tensor_hook_version_counter_not_shared", # assertEqual + "test_saved_tensors_hook_version_counter_not_shared", # assertEqual "test_post_accumulate_grad_hook_returns_not_None", # throws "test_custom_function_cycle", # assertEqual "test_mark_non_differentiable_mixed", # assertTrue @@ -2594,7 +2818,8 @@ def wrap_test_class(orig_cls): "test_custom_function_forward_mode_inplace_checks", # forward AD "test_custom_function_forward_mode_view_checks", # forward AD "test_custom_function_forward_mode_wrong_formula", # forward AD - "test_default_saved_variable_hooks_double_backward", # create_graph + "test_default_saved_tensors_hooks_double_backward", # create_graph + "test_node_post_hook_registered_during_unpack_hook", # 'NoneType' object has no attribute 'register_hook' "test_full_backward_hook_double_backward", # create_graph "test_function", # create_graph "test_grad", # create_graph @@ -2651,7 +2876,7 @@ def wrap_test_class(orig_cls): "test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173 # Category: FakeTensor "test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph - "test_wrapped_number_saved_variable_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original + "test_wrapped_number_saved_tensors_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original "test_grad_batched_grad", # torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads # Category: Divergence from eager diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index a1a38a95b9aca4..b7fde0bb9fa9fc 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -146,6 +146,9 @@ class KernelCounts(NamedTuple): "test_sgd_weight_decay_maximize_cuda": 4, "test_sgd_weight_decay_maximize_xpu": 4, "test_sgd_weight_decay_maximize_cpu": 4, + "test_sgd_weight_decay_cpu": 4, + "test_sgd_weight_decay_cuda": 4, + "test_sgd_weight_decay_xpu": 4, "test_sgd_momentum_weight_decay_foreach_cuda": 2, "test_sgd_momentum_weight_decay_foreach_xpu": 2, "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2, @@ -214,7 +217,7 @@ def build_opt_kwarg_db(): has_tensor_lr = False for key, val in kwargs.items(): - if not key == "lr" and ( + if (not key == "lr" and not key == "betas") and ( not isinstance(val, bool) or (isinstance(val, bool) and val) ): name += "_" + key @@ -223,6 +226,9 @@ def build_opt_kwarg_db(): has_tensor_lr = True name += "_tensor_lr" + if key == "betas" and isinstance(kwargs["betas"][0], torch.Tensor): + name += "_tensor_betas" + name += f"_{device}" kwargs["device"] = device @@ -264,7 +270,10 @@ def build_opt_kwarg_db(): try: from .test_torchinductor import check_model, check_model_gpu except ImportError: - from test_torchinductor import check_model, check_model_gpu + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": @@ -364,6 +373,16 @@ def test_fn(self): kwargs["lr"] = kwargs["lr"].to(device) kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device) + if "betas" in kwargs and isinstance(kwargs["betas"][0], torch.Tensor): + kwargs["betas"] = ( + kwargs["betas"][0].to(device), + kwargs["betas"][1].to(device), + ) + kwargs_compiled["betas"] = ( + kwargs_compiled["betas"][0].to(device), + kwargs_compiled["betas"][1].to(device), + ) + torch._dynamo.reset() torch._inductor.metrics.reset() input = torch.ones([10, 10], device=device) @@ -819,7 +838,7 @@ def test_S429861(self): try: from . import s429861_repro except ImportError: - import s429861_repro + import s429861_repro # @manual forward = s429861_repro.forward diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index ced9637bbb5310..d4580a9956e8e7 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -1,11 +1,13 @@ # Owner(s): ["module: inductor"] import itertools +import unittest import torch import torch._dynamo.testing from torch._higher_order_ops.associative_scan import associative_scan from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import ( + decorateIf, instantiate_parametrized_tests, parametrize, ) @@ -698,9 +700,15 @@ def test_while_loop_with_outer_buffers(self, device, dynamic): class AssociativeScanTests(TestCase): @requires_gpu - @parametrize("device", [GPU_TYPE]) + @parametrize("combine_mode", ["pointwise", "generic"]) @parametrize("backend", ["inductor"]) - def test_pointwise_associative_scan_CUDA_flip(self, device, backend): + @parametrize("device", [torch.device("cpu"), GPU_TYPE]) + # This test will fail as flip in combination with particular input lenghts + # produces weird results. + # This is under investigations in + # https://github.com/pytorch/pytorch/issues/131805 + @decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE) + def test_associative_scan_CUDA_flip(self, combine_mode, backend, device): def fct(x: torch.Tensor, y: torch.Tensor): return x + y @@ -712,17 +720,35 @@ def fct(x: torch.Tensor, y: torch.Tensor): ) associative_scan2 = associative_scan - result1 = associative_scan1(fct, x, 0, reverse=False) - result2 = associative_scan2(fct, x, 0, reverse=False) + if combine_mode == "pointwise" and device == torch.device("cpu"): + with self.assertRaisesRegex(Exception, r"."): + associative_scan1( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) + + # Skipping test because combine_mode currently only suppors CUDA tensors + return + + result1 = associative_scan1( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) + result2 = associative_scan2( + fct, x, 0, reverse=False, combine_mode=combine_mode + ) result3 = torch.cumsum(x, 0) self.assertEqual(result1, result2) self.assertEqual(result1, result3) # Flip only non-compiled and compare with compiled reverse=True - result1 = associative_scan1(fct, x, 0, reverse=True) + result1 = associative_scan1( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) result2 = torch.flip( - associative_scan2(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan2( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) @@ -731,9 +757,14 @@ def fct(x: torch.Tensor, y: torch.Tensor): # Flip only compiled and compare with non-compiled reverse=True result1 = torch.flip( - associative_scan1(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan1( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], + ) + result2 = associative_scan2( + fct, x, 0, reverse=True, combine_mode=combine_mode ) - result2 = associative_scan2(fct, x, 0, reverse=True) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) self.assertEqual(result1, result2) @@ -741,10 +772,16 @@ def fct(x: torch.Tensor, y: torch.Tensor): # Use reverse=False, but flip both results before and after result1 = torch.flip( - associative_scan1(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan1( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result2 = torch.flip( - associative_scan2(fct, torch.flip(x, [0]), 0, reverse=False), [0] + associative_scan2( + fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode + ), + [0], ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) @@ -752,8 +789,12 @@ def fct(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result1, result3) # Reverse=True - result1 = associative_scan1(fct, x, 0, reverse=True) - result2 = associative_scan2(fct, x, 0, reverse=True) + result1 = associative_scan1( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) + result2 = associative_scan2( + fct, x, 0, reverse=True, combine_mode=combine_mode + ) result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) self.assertEqual(result1, result2) diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index dbe92859a08d66..bedb9a64727c7c 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -12,7 +12,7 @@ try: - import triton + import triton # @manual except ImportError: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_cpp_wrapper_hipify.py b/test/inductor/test_cpp_wrapper_hipify.py index 43b1909e927df8..62f23ad3abc768 100644 --- a/test/inductor/test_cpp_wrapper_hipify.py +++ b/test/inductor/test_cpp_wrapper_hipify.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] import torch from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper -from torch._inductor.codegen.codegen_device_driver import cuda_kernel_driver +from torch._inductor.codegen.common import get_device_op_overrides from torch._inductor.test_case import run_tests, TestCase @@ -34,7 +34,8 @@ def test_hipify_basic_declaration(self) -> None: self.assertEqual(result, expected) def test_hipify_aoti_driver_header(self) -> None: - header = cuda_kernel_driver() + cuda_codegen = get_device_op_overrides("cuda") + header = cuda_codegen.kernel_driver() expected = """ #define CUDA_DRIVER_CHECK(EXPR) \\ do { \\ diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 0d4d7cb6a1dbe2..91ecebd5ca11ec 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -28,11 +28,11 @@ test_torchinductor_dynamic_shapes, ) except ImportError: - import test_cpu_repro - import test_cpu_select_algorithm - import test_mkldnn_pattern_matcher - import test_torchinductor - import test_torchinductor_dynamic_shapes + import test_cpu_repro # @manual=fbcode//caffe2/test/inductor:test_cpu_repro-library + import test_cpu_select_algorithm # @manual=fbcode//caffe2/test/inductor:cpu_select_algorithm_cpu-library + import test_mkldnn_pattern_matcher # @manual + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + import test_torchinductor_dynamic_shapes # @manual=fbcode//caffe2/test/inductor:test_inductor-library_dynamic_shapes except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) @@ -87,30 +87,7 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): } ) if config.abi_compatible: - xfail_list = [ - "test_conv2d_binary_inplace_fusion_failed_cpu", - "test_conv2d_binary_inplace_fusion_pass_cpu", - "test_dynamic_qlinear_cpu", - "test_dynamic_qlinear_qat_cpu", - "test_lstm_packed_change_input_sizes_cpu", - "test_qconv2d_add_cpu", - "test_qconv2d_add_relu_cpu", - "test_qconv2d_cpu", - "test_qconv2d_dequant_promotion_cpu", - "test_qconv2d_maxpool2d_linear_dynamic_cpu", - "test_qconv2d_relu_cpu", - "test_qlinear_cpu", - "test_qlinear_add_cpu", - "test_qlinear_add_relu_cpu", - "test_qlinear_dequant_promotion_cpu", - "test_qlinear_gelu_cpu", - "test_qlinear_relu_cpu", - *[ - func - for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) - if func.startswith("test_linear_with_pointwise") - ], - ] + xfail_list = [] for test_name in xfail_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( ("cpp_wrapper",), is_skip=False @@ -119,7 +96,11 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase): f"{test_name}_dynamic_shapes" ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False) skip_list = [ - "test_multihead_attention_cpu", + *[ + func + for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU()) + if func.startswith("test_linear_with_pointwise") + ], ] for test_name in skip_list: test_failures_cpp_wrapper[test_name] = test_torchinductor.TestFailure( @@ -147,7 +128,7 @@ def make_test_case( assert callable(func), "not a callable" func = slowTest(func) if slow else func - @config.patch(cpp_wrapper=True) + @config.patch(cpp_wrapper=True, search_autotune_cache=False) def fn(self): tests.setUpClass() tests.setUp() @@ -209,8 +190,12 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["op_mkldnn__convolution_pointwise_binary.call"], - ["op_mkldnn__convolution_pointwise__binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise_binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise__binary.call"], ], ), BaseTest( @@ -219,8 +204,12 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["op_mkldnn__convolution_pointwise__binary.call"], - ["op_mkldnn__convolution_pointwise_binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise__binary.call"], + None + if config.abi_compatible + else ["op_mkldnn__convolution_pointwise_binary.call"], ], ), BaseTest( @@ -319,11 +308,13 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestDynamicPatternMatcher(), condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, func_inputs=[ - [ + None + if config.abi_compatible + else [ "op_onednn_qconv2d_pointwise_.call", "op_quantized_max_pool2d_.call", "op_onednn_qlinear_pointwise_tensor.call", - ] + ], ], ), BaseTest( diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 8c2691aaac4479..bc65658a1e2ec1 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -10,30 +10,20 @@ from typing import Callable from unittest.mock import patch -import numpy as np -import sympy - import torch from torch import nn from torch._C import FileCheck from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same from torch._inductor import config, cpu_vec_isa, metrics, test_operators -from torch._inductor.codegen.common import OptimizationContext -from torch._inductor.codegen.cpp import ( - CppOverrides, - CppVecKernelChecker, - CppVecOverrides, -) +from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides from torch._inductor.compile_fx import ( compile_fx, compile_fx_inner, complex_memory_overlap, ) from torch._inductor.graph import GraphLowering -from torch._inductor.ir import InterpreterShim from torch._inductor.utils import timed -from torch._inductor.virtualized import V from torch._prims_common import is_float_dtype from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F @@ -52,7 +42,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) @@ -1917,6 +1907,16 @@ def test_bitwise_right_shift(self): res = cfn(x, bit_num) self.assertEqual(res_aten_eager, res) + def test_view_dtype(self): + def f(x): + return x.view(torch.int32) >> 2 + + input = torch.ones(16, 16) + res_aten_eager = f(input) + cfn = torch.compile(f) + res = cfn(input) + self.assertEqual(res_aten_eager, res) + @patch("torch.cuda.is_available", lambda: False) def test_scatter_using_atomic_add(self): def fn(a, dim, index, b): @@ -1972,6 +1972,32 @@ def _internal_check( with config.patch({"cpp.dynamic_threads": True}), set_num_threads(1): _internal_check(fn, inps, "aten.scatter_reduce_") + @patch("torch.cuda.is_available", lambda: False) + @requires_vectorization + @torch._inductor.config.patch({"cpp.fallback_scatter_reduce_sum": False}) + def test_scatter_using_atomic_add_vec(self): + def fn(a, dim, index, b): + return aten.scatter(a, dim, index, b, reduce="add") + + inps = ( + torch.zeros(1, 1, 25), + 2, + torch.tensor([[[3, 5, 7, 9] * 5]]), + torch.ones(1, 1, 25), + ) + torch._dynamo.reset() + metrics.reset() + self.common(fn, inps) + assert metrics.generated_cpp_vec_kernel_count == 2 + + with set_num_threads(1), config.patch( + {"fx_graph_cache": False, "fx_graph_remote_cache": False} + ): + torch._dynamo.reset() + metrics.reset() + self.common(fn, inps) + assert metrics.generated_cpp_vec_kernel_count == 2 + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @requires_vectorization @patch("torch.cuda.is_available", lambda: False) @@ -2349,230 +2375,6 @@ def fn(x): assert metrics.cpp_to_dtype_count == 2 check_metrics_vec_kernel_count(1) - @requires_vectorization - @patch("torch.cuda.is_available", lambda: False) - def test_cpp_vec_constant_checker(self): - _graph: torch.fx.Graph = torch.fx.Graph() - a: torch.fx.Node = _graph.create_node("placeholder", "ops") - iv: torch.fx.Node = _graph.create_node("placeholder", "iv") - fv: torch.fx.Node = _graph.create_node("placeholder", "fv") - b: torch.fx.Node = _graph.create_node( - "call_method", - "constant", - args=( - a, - iv, - torch.int64, - ), - ) - c: torch.fx.Node = _graph.create_node( - "call_method", - "constant", - args=( - a, - fv, - torch.double, - ), - ) - d: torch.fx.Node = _graph.create_node( - "call_method", - "ge", - args=( - a, - b, - b, - ), - ) - _graph.output((d, c)) - - def get_index(): - return "" - - submodules = {"get_index": get_index} - - graph_lowering = GraphLowering( - torch.fx.GraphModule(submodules, _graph), - shape_env=None, - ) - - def set_opt_dtype(graph): - for node in graph.nodes: - if node.target == "constant": - if OptimizationContext.key in node.meta: - opt_ctx = node.meta[OptimizationContext.key] - else: - opt_ctx = OptimizationContext() - opt_ctx.dtype = node.args[-1] - node.meta[OptimizationContext.key] = opt_ctx - - with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( - graph_lowering - ): - # The moset inner loop variable is used in the index_expr - tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float) - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - i32_iinfo = np.iinfo(np.int32) - f32_iinfo = np.finfo(np.float32) - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, np.inf - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, -np.inf - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min - 1, f32_iinfo.min - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max + 1, f32_iinfo.max - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.min, f32_iinfo.min * (1 + 1e-5) - ) - self.assertTrue(vec_checker.simd_vec) - - vec_checker.simd_vec = True - set_opt_dtype(_graph) - InterpreterShim(_graph, submodules).run( - V.get_ops_handler(), i32_iinfo.max, f32_iinfo.max * (1 + 1e-5) - ) - self.assertTrue(vec_checker.simd_vec) - - @requires_vectorization - @patch("torch.cuda.is_available", lambda: False) - def test_cpp_vec_index_expr_checker(self): - _graph: torch.fx.Graph = torch.fx.Graph() - a: torch.fx.Node = _graph.create_node("placeholder", "ops") - b: torch.fx.Node = _graph.create_node("call_module", "get_index", args=()) - c: torch.fx.Node = _graph.create_node( - "call_method", - "index_expr", - args=( - a, - b, - torch.int64, - ), - ) - d: torch.fx.Node = _graph.create_node( - "call_method", - "ge", - args=( - a, - c, - c, - ), - ) - _graph.output(d) - - def get_index(): - return "" - - submodules = {"get_index": get_index} - graph_lowering = GraphLowering( - torch.fx.GraphModule(submodules, _graph), - shape_env=None, - ) - with patch.object(graph_lowering, "wrapper_code", ""), V.set_graph_handler( - graph_lowering - ): - itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")] - - tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float) - # The most inner loop variable is used in the index_expr - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] - - ranges = [0, 100, 200] - vec_checker.itervars = itervars[:2] - vec_checker.ranges = ranges[:2] - submodules = {"get_index": get_index} - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - # Most inner loop variable irrevalant - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] ** 2 + 2 * itervars[0] + itervars[1] - - ranges = [0, 100, 200] - vec_checker.itervars = itervars - vec_checker.ranges = ranges - submodules = {"get_index": get_index} - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - i32_iinfo = np.iinfo(np.int32) - _max_value = i32_iinfo.max + 1 - ranges = [_max_value, _max_value, _max_value] - # Most inner loop variable irrevalant but max value is greater than - # the max value of INT32 - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return itervars[0] - - submodules = {"get_index": get_index} - vec_checker.itervars = itervars - vec_checker.ranges = ranges - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - - # Most inner loop variable irrevalant but min value is greater than - # the min value of INT32 - with CppVecKernelChecker( - args=None, num_threads=1, tiling_factor=tiling_factor - ) as vec_checker: - - def get_index(): - return -itervars[0] - 2 - - submodules = {"get_index": get_index} - vec_checker.itervars = itervars - vec_checker.ranges = ranges - InterpreterShim(_graph, submodules).run(V.get_ops_handler()) - self.assertTrue(vec_checker.simd_vec) - @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_maxpool2d_cpu_only(self): @@ -3620,7 +3422,6 @@ def fn(): dtype if dtype else torch.float32, ) - @config.patch("cpp.enable_tiling_heuristics", False) def test_group_norm_vec(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -3643,6 +3444,18 @@ def forward(self, x): # 2 generated kernels (one for var_mean, the other for result) check_metrics_vec_kernel_count(2) + # check loop split optimization + if fmt == torch.channels_last: + torch._dynamo.reset() + metrics.reset() + with torch.no_grad(): + opt_mod = torch.compile(mod) + _, code = run_and_get_cpp_code(opt_mod, x) + # check that there are no non_contiguous loads + FileCheck().check_count("__at_align__ std::array", 0, exactly=True).run( + code + ) + def test_int_div_vec(self): def fn(x, y, mode): return torch.div(x, y, rounding_mode=mode) @@ -4252,6 +4065,68 @@ def fn(x1, x2, x3): n_veckernel = 6 if op is torch.masked.mean else 3 check_metrics_vec_kernel_count(n_veckernel) + @requires_vectorization + def test_full_bits_lowp(self): + def check_use_full_bits(func, shapes, dtype, mixed, check_vecn): + example_inputs = [torch.randn(shape, dtype=dtype) for shape in shapes] + if mixed: + example_inputs[0] = example_inputs[0].to( + dtype=torch.half if dtype == torch.bfloat16 else torch.bfloat16 + ) + f_opt = torch.compile()(func) + _, code = run_and_get_cpp_code(f_opt, *example_inputs) + if check_vecn: + self.assertTrue( + "at::vec::VectorizedN" in code or "at::vec::convert None: @@ -1262,6 +1280,27 @@ def forward(self, x): out2 = m(input_tensor) self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + @config.patch("triton.cudagraphs", True) + def test_cpu_index(self): + @torch.compile(fullgraph=True) + def fn(x): + return x[torch.arange(32)] + + result, (graph,) = run_and_get_graph_lowering( + fn, torch.randn(64, device="cuda") + ) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + + inp = torch.randn(64, device="cuda", requires_grad=True) + result, (graph,) = run_and_get_graph_lowering(fn, inp) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + + result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward()) + self.assertEqual(graph.disable_cudagraphs_reason, None) + self.assertEqual(graph.device_types, {"cuda"}) + def test_reflection_pad_loop_order(self): def fn(x, y): a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect") diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index cfff3baa62765c..a2675ed3231dd8 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -2445,6 +2445,17 @@ def iter(batch_size: int, mod: torch.nn.Module): exactly=True, ).run("\n".join(captured_output)) + @torch._inductor.config.patch("cpp_wrapper", 1) + def test_cpp_wrapper(self): + def f(x): + return torch.sin(x) + + compiled = torch.compile(f, mode="reduce-overhead") + example_input = torch.randn(10, device="cuda") + compiled_result = self.run_twc(compiled, example_input) + eager_result = f(example_input) + self.assertEqual(compiled_result, eager_result) + instantiate_parametrized_tests(CudaGraphTreeTests) if __name__ == "__main__": diff --git a/test/inductor/test_cudagraph_trees_expandable_segments.py b/test/inductor/test_cudagraph_trees_expandable_segments.py index 9f57aa0d30a494..aa1e85fd82d149 100644 --- a/test/inductor/test_cudagraph_trees_expandable_segments.py +++ b/test/inductor/test_cudagraph_trees_expandable_segments.py @@ -18,12 +18,14 @@ try: from .test_cudagraph_trees import CudaGraphTreeTests except ImportError: - from test_cudagraph_trees import CudaGraphTreeTests # noqa: F401 + from test_cudagraph_trees import ( # noqa: F401 # @manual=fbcode//caffe2/test/inductor:cudagraph_trees-library + CudaGraphTreeTests, + ) REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent sys.path.insert(0, str(REPO_ROOT)) -from tools.stats.import_test_stats import get_disabled_tests +from tools.stats.import_test_stats import get_disabled_tests # @manual # Make sure to remove REPO_ROOT after import is done diff --git a/test/inductor/test_debug_trace.py b/test/inductor/test_debug_trace.py index 304cba001106fb..701d4e6cd9f5a5 100644 --- a/test/inductor/test_debug_trace.py +++ b/test/inductor/test_debug_trace.py @@ -18,7 +18,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py index 93798fc5aedf3e..68997635f3cfb9 100644 --- a/test/inductor/test_decompose_mem_bound_mm.py +++ b/test/inductor/test_decompose_mem_bound_mm.py @@ -1,7 +1,6 @@ # Owner(s): ["module: inductor"] import logging -import unittest import torch import torch._inductor @@ -13,15 +12,13 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.triton_utils import requires_gpu class MyModule(torch.nn.Module): def __init__( - self, n_input: int, n_output: int, has_bias: bool, device="cuda" + self, n_input: int, n_output: int, has_bias: bool, device=GPU_TYPE ) -> None: super().__init__() self.linear = torch.nn.Linear(n_input, n_output, bias=has_bias) @@ -48,7 +45,7 @@ def forward(self, input1, input2): return output -@requires_cuda +@requires_gpu @torch._inductor.config.patch( post_grad_fusion_options={ "decompose_mm_pass": {}, @@ -89,12 +86,12 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): ) def test_decompose_bmm(self, b, m, n, k, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(b, m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(b, k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(b, m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(b, k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule2().to("cuda") + module = MyModule2().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -102,7 +99,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -113,7 +110,7 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 3 if should_decompose else 0 + expected_val = 3 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_bmm"], expected_val, @@ -127,11 +124,11 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose): @parametrize("has_bias", [True, False]) def test_decompose_linear(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module) input = [input] ref = module(*input) @@ -139,7 +136,7 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -172,13 +169,13 @@ def test_decompose_linear(self, m, n, k, has_bias, should_decompose): def test_decompose_linear_mixed_precision( self, m, n, k, has_bias, should_decompose ): - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module) input = [input] ref = module(*input) @@ -186,7 +183,7 @@ def test_decompose_linear_mixed_precision( self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -218,12 +215,12 @@ def test_decompose_linear_mixed_precision( @parametrize("has_bias", [True, False]) def test_decompose_mm(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule3().to("cuda") + module = MyModule3().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -231,7 +228,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -243,7 +240,7 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -256,14 +253,14 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose): ) @parametrize("has_bias", [True, False]) def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose): - with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.amp.autocast(device_type=GPU_TYPE, dtype=torch.bfloat16): torch._logging.set_logs(inductor=logging.DEBUG) - mat1 = torch.randn(m, k, device="cuda").requires_grad_(True) - mat2 = torch.randn(k, n, device="cuda").requires_grad_(True) + mat1 = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) + mat2 = torch.randn(k, n, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule3().to("cuda") + module = MyModule3().to(GPU_TYPE) traced = torch.compile(module) input = [mat1, mat2] ref = module(*input) @@ -271,7 +268,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"], expected_val, @@ -283,7 +280,7 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) self.compare_parameters(module, traced) self.compare_gradients(module, traced) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 self.assertEqual( counters["inductor"]["decompose_mm"] - decompose_mm_fwd, expected_val, @@ -294,11 +291,11 @@ def test_decompose_mm_mixed_precision(self, m, n, k, has_bias, should_decompose) @parametrize("has_bias", [True, False]) def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): torch._logging.set_logs(inductor=logging.DEBUG) - input = torch.randn(m, k, device="cuda").requires_grad_(True) + input = torch.randn(m, k, device=GPU_TYPE).requires_grad_(True) counters.clear() - module = MyModule(k, n, has_bias).to("cuda") + module = MyModule(k, n, has_bias).to(GPU_TYPE) traced = torch.compile(module, dynamic=True) input = [input] ref = module(*input) @@ -306,7 +303,7 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_pred(module, traced, input) - expected_val = 1 if should_decompose else 0 + expected_val = 1 if should_decompose and HAS_CUDA else 0 if has_bias: self.assertEqual( counters["inductor"]["decompose_addmm"], @@ -319,9 +316,13 @@ def test_dynamic_shape(self, m, n, k, has_bias, should_decompose): self.compare_parameters(module, traced) self.compare_gradients(module, traced) + expected_val = 0 + if HAS_CUDA: + expected_val = 1 if has_bias else 2 + self.assertEqual( counters["inductor"]["decompose_mm"], - 1 if has_bias else 2, + expected_val, ) counters.clear() @@ -330,8 +331,8 @@ def test_realize_input(self): k = 5 n = 2 torch._logging.set_logs(inductor=logging.DEBUG) - input1 = torch.randn(m, k, device="cuda").T.contiguous() - input2 = torch.randn(k, n, device="cuda") + input1 = torch.randn(m, k, device=GPU_TYPE).T.contiguous() + input2 = torch.randn(k, n, device=GPU_TYPE) @torch.compile() def foo(x, y): @@ -339,8 +340,12 @@ def foo(x, y): out, code = run_and_get_code(foo, input1, input2) - # two kernels generated - FileCheck().check_count(".run(", 2, exactly=True).run(code[0]) + if GPU_TYPE == "xpu": + # only 1 kernel generated on the XPU stack + FileCheck().check_count(".run(", 1, exactly=True).run(code[0]) + else: + # two kernels generated + FileCheck().check_count(".run(", 2, exactly=True).run(code[0]) if __name__ == "__main__": diff --git a/test/inductor/test_distributed_patterns.py b/test/inductor/test_distributed_patterns.py index a647a16c421da4..fd446370434e94 100644 --- a/test/inductor/test_distributed_patterns.py +++ b/test/inductor/test_distributed_patterns.py @@ -26,23 +26,14 @@ def reduce_scatter(t): return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone() def fw_pre_hook(mod, inp): - if not compiled_autograd.compiled_autograd_enabled: - # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead. - mod.unsharded_weight.untyped_storage().resize_( - mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() - ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - mod.unsharded_weight.copy_(all_gather(mod.sharded_weight)) - else: - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - torch.ops.fsdp.set_( - mod.unsharded_weight, all_gather(mod.sharded_weight) - ) - mod.weight = mod.unsharded_weight + mod.unsharded_weight.untyped_storage().resize_( + mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() + ) + with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( + mod.unsharded_weight + ): + torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) + mod._parameters["weight"] = mod.unsharded_weight # Forward: # mod.sharded_weight = local_shard (always) @@ -54,27 +45,18 @@ def fw_pre_hook(mod, inp): # mod.unsharded_weight = zero-sized allgather def fw_post_hook(mod, inp, out): - mod.weight = mod.sharded_weight + mod._parameters["weight"] = mod.sharded_weight mod.unsharded_weight.untyped_storage().resize_(0) def bw_pre_hook(mod, gO): - if not compiled_autograd.compiled_autograd_enabled: - # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead. - mod.unsharded_weight.untyped_storage().resize_( - mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() - ) - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - mod.unsharded_weight.copy_(all_gather(mod.sharded_weight)) - else: - with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( - mod.unsharded_weight - ): - torch.ops.fsdp.set_( - mod.unsharded_weight, all_gather(mod.sharded_weight) - ) - mod.weight = mod.unsharded_weight + mod.unsharded_weight.untyped_storage().resize_( + mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size() + ) + with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( + mod.unsharded_weight + ): + torch.ops.fsdp.copy_(mod.unsharded_weight, all_gather(mod.sharded_weight)) + mod._parameters["weight"] = mod.unsharded_weight # Backward: # mod.sharded_weight = local_shard (always) @@ -88,7 +70,7 @@ def bw_pre_hook(mod, gO): def bw_post_hook(mod, gI, gO): grad = mod.weight.grad new_grad = reduce_scatter(grad) - mod.weight = mod.sharded_weight + mod._parameters["weight"] = mod.sharded_weight mod.weight.grad = new_grad mod.unsharded_weight.untyped_storage().resize_(0) @@ -99,7 +81,6 @@ def bw_post_hook(mod, gI, gO): m.sharded_weight = nn.Parameter(reduce_scatter(m.weight)) m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight)) m.unsharded_weight.untyped_storage().resize_(0) - del m.weight m.register_full_backward_pre_hook(bw_pre_hook) m.register_full_backward_hook(bw_post_hook) @@ -466,7 +447,6 @@ def test_fake_distributed_aot_eager(self): @requires_gpu() @torch._functorch.config.patch(recompute_views=True) def test_fake_distributed_inductor(self): - # TODO: fix .set_ lowering in CPU inductor, and enable the CPU test. m1, inp1 = init_fake_distributed(GPU_TYPE) out1 = steps(m1, inp1) diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index 55633677661753..90628a4c6a135f 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -4,7 +4,6 @@ import itertools import os import sys -import unittest import torch from torch import nn @@ -17,22 +16,16 @@ from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN +from torch.testing._internal.common_utils import TEST_WITH_ASAN from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - importlib.import_module("functorch") importlib.import_module("filelock") -from inductor.test_torchinductor import copy_tests +from inductor.test_torchinductor import ( + copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library +) class ConvOp(nn.Module): diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 22fea818b56c46..6f972e46a1d987 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -11,7 +11,7 @@ try: - from extension_backends.cpp.extension_codegen_backend import ( + from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 ExtensionCppWrapperCodegen, ExtensionScheduling, ExtensionWrapperCodegen, @@ -38,7 +38,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a657a371081415..1cb6354275969e 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -3,9 +3,10 @@ import functools import string +import unittest from collections import namedtuple -from contextlib import nullcontext -from typing import Callable, Optional +from contextlib import contextmanager, nullcontext +from typing import Callable, Optional, Tuple from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -28,19 +29,27 @@ ) from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU +from torch.testing._internal.common_utils import TEST_WITH_ROCM from torch.utils._triton import has_triton # Skip tests if Triton is not available supported_platform = skipUnless( torch.cuda.is_available() - and torch.version.hip is None and has_triton() and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) +# Use this decorator only when hitting Triton bugs on H100 +running_on_a100_or_rocm_only = skipUnless( + torch.cuda.is_available() + and has_triton() + and (torch.cuda.get_device_capability() == (8, 0) or torch.version.hip is not None), + "Requires (A100 or ROCm) and Triton", +) + Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) torch.set_float32_matmul_precision("high") @@ -48,6 +57,22 @@ Tensor = torch.Tensor +@contextmanager +def temp_float32_matmul_precision(precision: str): + """ + Temporarily set the float32 matmul precision and restore it after the context is exited. + + Args: + precision (str): The precision to set ('highest', 'high', or 'medium'). + """ + original_precision = torch.get_float32_matmul_precision() + try: + torch.set_float32_matmul_precision(precision) + yield + finally: + torch.set_float32_matmul_precision(original_precision) + + def rmse(ref, res): """ Calculate root mean squared error @@ -187,6 +212,17 @@ def _trig2(score, b, h, m, n): S = 2048 D = 64 +test_Hq_Hkv = [ + (4, 2), + (4, 1), +] + +test_Bq_Bkv = [ + (3, 1), + (4, 1), + (5, 1), +] + def query_key_value_clones( query: torch.Tensor, @@ -375,7 +411,11 @@ def run_dynamic_test( S: int = S, D: int = D, ): - sdpa_partial = create_attention(score_mod) + # If the seqlen becomes smaller than the seqlen of the previous batch, + # we can still reuse the block_mask created from a larger seqlen. + MAX_S = S + block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) # The first eager batch, shape (B, H, S, D) q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) @@ -449,6 +489,17 @@ def run_dynamic_test( ) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) + # The third iteration, shape (B * 2, H, S * 2, D) + # Since seqlen is larger than the seqlen in block_mask, throw errors. + S = int(S * 4) + q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + with self.assertRaisesRegex( + torch._dynamo.exc.BackendCompilerFailed, "Q seqlen must be smaller than" + ): + compiled_sdpa(q3, k3, v3) + def run_automatic_dynamic_test( self, score_mod: Callable, @@ -458,7 +509,9 @@ def run_automatic_dynamic_test( S: int = S, D: int = D, ): - sdpa_partial = create_attention(score_mod) + MAX_S = S + block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S) + sdpa_partial = create_attention(score_mod, block_mask=block_mask) # The first eager batch, shape (B, H, S, D) q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") @@ -526,7 +579,7 @@ def run_automatic_dynamic_test( def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): self.run_test(score_mod, dtype) - @supported_platform + @running_on_a100_or_rocm_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( @@ -540,7 +593,7 @@ def test_builtin_score_mods_seqlen_lt_default_sparse_block_size( ) self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D) - @supported_platform + @running_on_a100_or_rocm_only @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_seqlen_lt_custom_sparse_block_size( @@ -558,16 +611,14 @@ def causal_mask(b, h, q, kv): ) self.run_test_with_call(attention, dtype, B, H, 64, D, B, H, 64, D) - @expectedFailure # TODO: supports block sparsity with dynamic shapes @supported_platform - @common_utils.parametrize("dtype", test_dtypes) + @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable): self.run_dynamic_test(score_mod, dtype) - @expectedFailure # TODO: supports block sparsity with dynamic shapes @supported_platform - @common_utils.parametrize("dtype", test_dtypes) + @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) def test_builtin_score_mods_automatic_dynamic( self, dtype: torch.dtype, score_mod: Callable @@ -593,6 +644,76 @@ def test_builtin_score_mods_different_seqlen( D, ) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast( + self, + dtype: torch.dtype, + batch_dims: Tuple[int, int], + head_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + self.run_test( + score_mod, + dtype, + Bq, + Hq, + S, + D, + Bkv, + Hkv, + S, + D, + ) + + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast_causal_mask( + self, + dtype: torch.dtype, + batch_dims: Tuple[int, int], + head_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, S, S) + attention = functools.partial( + flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv) + ) + + self.run_test_with_call( + attention, + torch.float16, + Bq, + Hq, + S, + D, + Bkv, + Hkv, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("score_mod", test_score_mods) @@ -1375,6 +1496,34 @@ def test_aot_eager_gradcheck(self, score_mod): ) ) + @supported_platform + def test_eager_backward_strides(self): + class Repro(torch.nn.Module): + def __init__(self): + super().__init__() + self.qkv_proj = torch.nn.Linear(256, 256 * 3) + self.n_head = 256 // 64 + self.d_attn = 256 + + def forward(self, x): + n_batch, n_ctx, _ = x.shape + q, k, v = self.qkv_proj(x).split( + [self.d_attn, self.d_attn, self.d_attn], dim=2 + ) + q = q.reshape(n_batch, n_ctx, self.n_head, -1) + k = k.reshape(n_batch, n_ctx, self.n_head, -1) + v = v.reshape(n_batch, n_ctx, self.n_head, -1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + x = torch.nn.attention.flex_attention.flex_attention(q, k, v) + return x + + model = Repro().cuda() + x = torch.randn((1, 512, 256), device="cuda", requires_grad=True) + out = torch.compile(model, backend="aot_eager")(x) + out.backward(torch.ones_like(out)) + @supported_platform def test_differentiable_logsumexp_gradcheck(self): make_tensor = functools.partial( @@ -1433,6 +1582,37 @@ def test_differentiable_logsumexp_compiled(self): v_grad, v_grad2, atol=tolerance.atol, rtol=tolerance.rtol ) + @supported_platform + def test_float32_matmul_precision(self): + make_tensor = functools.partial( + torch.zeros, + (2, 2, 128, 32), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + query.fill_(0.2) + key.fill_(0.3) + value.fill_(0.4) + + query.requires_grad = True + key.requires_grad = True + value.requires_grad = True + + def score_mod(score, b, h, q, kv): + return score * 2 + + with temp_float32_matmul_precision("highest"): + out_eager = flex_attention(query, key, value, score_mod) + flex_compiled = torch.compile(flex_attention, fullgraph=True) + out_compiled = flex_compiled(query, key, value, score_mod) + + grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) + grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) + + torch.testing.assert_close(grads_eager, grads_compile) + @supported_platform @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) @@ -1481,6 +1661,52 @@ def mask_mod(b, h, q, kv): out = func(query, key, value, block_mask=block_mask) out.sum().backward() + @supported_platform + @common_utils.parametrize("mode", ["eager", "inductor"]) + @common_utils.parametrize( + "permute_order", + [ + (0, 1, 2, 3), # Default order + (1, 0, 2, 3), # Reverse order + (0, 2, 1, 3), # Mixed order + (2, 0, 1, 3), # Another mixed order + ], + ) + @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)]) + def test_flex_attention_stride_ordering(self, mode, permute_order, shape): + from torch._inductor.ir import get_stride_order + + # Setup + make_tensor = functools.partial( + torch.randn, + shape, + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + + # Create and permute tensors + query, key, value = make_tensor(), make_tensor(), make_tensor() + query = query.permute(permute_order) + key = key.permute(permute_order) + value = value.permute(permute_order) + + if mode == "inductor": + func = torch.compile(flex_attention, backend=mode, fullgraph=True) + else: + func = flex_attention + + out = func(query, key, value) + + out_stride_order = get_stride_order(out.stride()) + query_stride_order = get_stride_order(query.stride()) + + self.assertEqual( + out_stride_order, + query_stride_order, + f"Stride order mismatch: out {out_stride_order}, query {query_stride_order}", + ) + @supported_platform @common_utils.parametrize("compile", [True, False]) def test_fully_masked_out_rows_0_check(self, compile: bool): @@ -1507,7 +1733,7 @@ def mask_mod(b, h, q, kv): ) out, lse = flex(query, key, value, block_mask=block_mask, return_lse=True) self.assertEqual(out[:, :, M:, :].sum(), 0) - self.assertTrue((lse[:, :, M:] == 0.0).all()) + self.assertTrue((lse[:, :, M:] == -float("inf")).all()) loss = out.sum() + lse.sum() loss.backward() @@ -1624,6 +1850,133 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention, Q_S=S - 1, KV_S=S - 1) + @supported_platform + def test_force_write_lse(self): + make_tensor = functools.partial( + torch.randn, + (2, 2, 128, 16), + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + out_eager, lse_eager = flex_attention(query, key, value, return_lse=True) + + flex_compile = torch.compile(flex_attention, fullgraph=True) + out_compiled, lse_compiled = flex_compile(query, key, value, return_lse=True) + + torch.testing.assert_close(lse_eager, lse_compiled, atol=3e-3, rtol=0) + + @supported_platform + @common_utils.parametrize("backend", ["flex_attention", "flex_decode", "eager"]) + def test_lse_masked_output(self, backend): + if backend == "flex_decode": + if TEST_WITH_ROCM: + self.skipTest("backend=flex_decode is unsupported on ROCM, for now") + kernel_options = {"FORCE_USE_FLEX_ATTENTION": False} + flex_call = torch.compile(flex_attention, fullgraph=True) + elif backend == "flex_attention": + kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} + flex_call = torch.compile(flex_attention, fullgraph=True) + else: + kernel_options = {} + flex_call = flex_attention + + N_CTX = 96 + SLIDING_WINDOW = 64 + make_tensor = functools.partial( + torch.randn, + (2, 2, N_CTX, 64), + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + + def sliding_window_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + window_mask = q_idx - kv_idx <= SLIDING_WINDOW + return causal_mask & window_mask + + def global_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + window_mask = q_idx - kv_idx > SLIDING_WINDOW + return causal_mask & window_mask + + sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask( + sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX + ) + global_causal = torch.nn.attention.flex_attention.create_block_mask( + global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX + ) + + local_attn = functools.partial( + flex_call, + block_mask=sliding_window_causal, + return_lse=True, + kernel_options=kernel_options, + ) + global_attn = functools.partial( + flex_call, + block_mask=global_causal, + return_lse=True, + kernel_options=kernel_options, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + gradOut = make_tensor(requires_grad=False) + + x_local, lse_local = local_attn(q, k, v) + x_global, lse_global = global_attn(q, k, v) + + max_lse = torch.maximum(lse_local, lse_global) + lse_global = lse_global - max_lse + lse_local = lse_local - max_lse + lse_global = torch.exp(lse_global) + lse_local = torch.exp(lse_local) + x = ((x_local * lse_local[..., None]) + (x_global * lse_global[..., None])) / ( + lse_global[..., None] + lse_local[..., None] + ) + x.backward(gradOut) + flex_q_grad, flex_k_grad, flex_v_grad = q.grad, k.grad, v.grad + q.grad = None + k.grad = None + v.grad = None + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) + out.backward(gradOut) + + torch.testing.assert_close(x, out, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_q_grad, q.grad, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_k_grad, k.grad, atol=3e-3, rtol=2e-3) + torch.testing.assert_close(flex_v_grad, v.grad, atol=3e-3, rtol=2e-3) + + @supported_platform + def test_small_q_kv_len(self): + make_tensor = functools.partial( + torch.ones, + (1, 1, 1, 16), + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + kernel_options = {"FORCE_USE_FLEX_ATTENTION": True} + out_eager, lse_eager = flex_attention( + query, key, value, return_lse=True, kernel_options=kernel_options + ) + + flex_compile = torch.compile(flex_attention, fullgraph=True) + out_compiled, lse_compiled = flex_compile( + query, key, value, return_lse=True, kernel_options=kernel_options + ) + + assert torch.equal(out_eager, out_compiled) + assert torch.equal(lse_eager, lse_compiled) + + grads_eager = torch.autograd.grad(out_eager.sum(), (query, key, value)) + grads_compile = torch.autograd.grad(out_compiled.sum(), (query, key, value)) + + torch.testing.assert_close(grads_eager, grads_compile) + @supported_platform def test_causal_block_non_divisible_with_captured_buffer(self): Q_S = S - 3 @@ -1643,6 +1996,26 @@ def mask_mod(b, h, q, kv): self.run_test_with_call(attention, Q_S=Q_S, KV_S=KV_S) + @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + def test_qkv_and_block_mask_on_the_same_device(self): + make_tensor = functools.partial( + torch.ones, + (2, 2, 256, 32), + device="cuda:0", + dtype=torch.float32, + requires_grad=True, + ) + query, key, value = make_tensor(), make_tensor(), make_tensor() + + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, 256, 256, device="cuda:1") + with self.assertRaisesRegex( + RuntimeError, "Expect q/k/v and block_mask to be on the same device" + ): + torch.compile(flex_attention)(query, key, value, block_mask=block_mask) + @supported_platform def test_fw_bw_graph_correctness(self): cnt = CompileCounterWithBackend("aot_eager") @@ -1788,6 +2161,63 @@ def causal_mask(b, h, q, kv): self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity()) self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity()) + @supported_platform + def test_getitem(self): + offset = torch.zeros(8, device="cuda") + + def causal_mask(b, h, q, kv): + return (q + (offset[b] * 128)) >= kv + + block_mask = create_block_mask(causal_mask, 4, 2, 512, 512) + assert block_mask.kv_num_blocks.shape == (4, 2, 4) + assert block_mask.kv_indices.shape == (4, 2, 4, 4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2, 4) + assert new_block_mask.kv_indices.shape == (2, 4, 4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4, 4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 4) + assert new_block_mask.kv_indices.shape == (2, 1, 4, 4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] + assert new_block_mask.kv_num_blocks.shape == (2, 1, 1) + assert new_block_mask.kv_indices.shape == (2, 1, 1, 4) + + # slicing on batch, head, and query dimension + q_index = torch.tensor([0], dtype=torch.int32) + new_block_mask = block_mask[:, :, q_index] + + self.assertEqual(new_block_mask.kv_num_blocks.ndim, 3) + self.assertEqual(new_block_mask.kv_indices.ndim, 4) + torch.testing.assert_close( + new_block_mask.kv_num_blocks, + block_mask.kv_num_blocks[:, :, q_index], + ) + torch.testing.assert_close( + new_block_mask.kv_indices, block_mask.kv_indices[:, :, q_index, :] + ) + + if block_mask.full_kv_num_blocks is not None: + assert new_block_mask.full_kv_num_blocks is not None + assert new_block_mask.full_kv_indices is not None + torch.testing.assert_close( + new_block_mask.full_kv_num_blocks, + block_mask.full_kv_num_blocks[:, :, q_index], + ) + torch.testing.assert_close( + new_block_mask.full_kv_indices, + block_mask.full_kv_indices[:, :, q_index, :], + ) + @supported_platform def test_block_mask_device_change(self): offset = torch.zeros(8, device="cuda") @@ -1813,6 +2243,16 @@ def causal_mask(b, h, q, kv): assert block_mask.q_indices.is_cuda assert block_mask.q_num_blocks.is_cuda + @supported_platform + def test_compiling_create_block_mask(self): + def mask_mod(b, h, q, kv): + return q >= kv + + block_mask = create_block_mask(mask_mod, 1, 1, 512, 512, _compile=True) + self.assertIsInstance(block_mask, BlockMask) + self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4))) + self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4))) + @supported_platform def test_block_mask_viz(self): def causal_mask(b, h, q, kv): @@ -2037,7 +2477,7 @@ def causal_mask(b, h, q_idx, kv_idx): *inputs, is_causal=True ) - torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=1e-3, rtol=0.0) + torch.testing.assert_close(causal_mask_out, sdpa_mask_out, atol=5e-3, rtol=0.0) common_utils.instantiate_parametrized_tests(TestFlexAttention) diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 3cf3981b0f1782..dde47f6a9a2673 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -4,7 +4,7 @@ import functools from collections import namedtuple from contextlib import nullcontext -from typing import Callable, Optional +from typing import Callable, Optional, Tuple from unittest import expectedFailure, skipUnless from unittest.mock import patch @@ -14,12 +14,14 @@ from torch.nn.attention.flex_attention import ( _create_empty_block_mask, _identity, + BlockMask, create_block_mask, flex_attention, ) from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 +from torch.testing._internal.common_utils import skipIfRocm from torch.utils._triton import has_triton @@ -185,6 +187,13 @@ def _trig2(score, b, h, m, n): (16, 16), ] +test_Bq_Bkv = [ + (3, 1), + (5, 1), + (8, 1), + (16, 1), +] + (Hq, Hkv) = (16, 8) @@ -253,7 +262,7 @@ def _check_out( def run_test( self, - score_mod: Callable, + score_mod: Optional[Callable], dtype: torch.dtype = torch.float16, Q_B: int = B, Q_H: int = Hq, @@ -263,7 +272,11 @@ def run_test( KV_H: int = Hkv, KV_S: int = S, V_D: int = D, + block_mask: Optional[BlockMask] = None, ): + assert ( + score_mod is not None or block_mask is not None + ), "Must provide score_mod or block_mask" assert Q_H % KV_H == 0 q = torch.randn( (Q_B, Q_H, Q_S, Q_D), @@ -280,7 +293,6 @@ def run_test( q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) - block_mask = None sdpa_partial = create_attention( score_mod, block_mask, enable_gqa=(not Q_H == KV_H) ) @@ -445,6 +457,37 @@ def test_strided_inputs(self, dtype: torch.dtype, k_s, v_s, head_dims): ref_out, compiled_out, atol=tolerance.atol, rtol=tolerance.rtol ) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes_fast) + @common_utils.parametrize("head_dims", test_Hq_Hkv) + @common_utils.parametrize("batch_dims", test_Bq_Bkv) + @common_utils.parametrize("score_mod", test_score_mods) + def test_kv_batch_broadcast( + self, + dtype: torch.dtype, + head_dims: Tuple[int, int], + batch_dims: Tuple[int, int], + score_mod: Callable, + ): + Hq, Hkv = head_dims + assert Hq % Hkv == 0 + + Bq, Bkv = batch_dims + assert Bq > 1 and Bkv == 1 + + self.run_test( + score_mod, + dtype, + Bq, + Hq, + 1, + D, + Bkv, + Hkv, + S, + D, + ) + @supported_platform @common_utils.parametrize("dtype", test_dtypes) def test_skip_odd_keys(self, dtype: torch.dtype): @@ -525,6 +568,7 @@ def bias_mod(score, b, h, q, kv): self.run_test(bias_mod, dtype) + @skipIfRocm @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) def test_load_from_bias_head_seq_batch(self, dtype): @@ -806,7 +850,7 @@ def mask_mod(b, h, q, kv): query, key, value, block_mask=block_mask, enable_gqa=True, return_lse=True ) self.assertEqual(out[:, :, M:, :].sum(), 0) - self.assertTrue((lse[:, :, M:] == 0.0).all()) + self.assertTrue((lse[:, :, M:] == -float("inf")).all()) loss = out.sum() + lse.sum() loss.backward() @@ -943,6 +987,36 @@ def func(q, k, v, score_mod): code[0] ) + @supported_platform + def test_non_sparse_mulitple_block_size(self): + def generate_causal_offset(offset: torch.Tensor): + def causal_offset_mask(b, h, q_idx, kv_idx): + return (offset + q_idx) >= kv_idx + + return causal_offset_mask + + def noop(score, b, h, q_idx, kv_idx): + return score + + mod = generate_causal_offset( + torch.tensor(192, device="cuda", dtype=torch.int32) + ) + block_mask = create_block_mask(mod, 1, 1, 1, 65) + + self.run_test( + score_mod=None, + dtype=torch.float32, + block_mask=block_mask, + Q_B=1, + Q_H=1, + Q_S=1, + Q_D=16, + KV_B=1, + KV_H=1, + KV_S=65, + V_D=16, + ) + @supported_platform def test_do_not_trigger_dynamic_shapes_on_empty_block_mask(self): torch._dynamo.reset() diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index af15fdcfd38532..e1ba38af845dac 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -21,7 +21,10 @@ try: from .test_torchinductor import check_model, check_model_cuda except ImportError: - from test_torchinductor import check_model, check_model_cuda + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": @@ -487,6 +490,43 @@ def fn(a0, a1, b0, b1): self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + @torch._dynamo.config.patch("assume_static_by_default", False) + @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) + def test_enable_dynamic_shapes_python_wrapper(self, op=torch._foreach_add): + def fn(a0, a1, b0, b1): + return op([a0, a1], [b0, b1]) + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs) + + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_cuda + @torch._dynamo.config.patch("automatic_dynamic_shapes", False) + @torch._dynamo.config.patch("assume_static_by_default", False) + @torch._inductor.config.patch("combo_kernel_foreach_dynamic_shapes", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_enable_dynamic_shapes_cpp_wrapper_cuda(self, op=torch._foreach_add): + def fn(a0, a1, b0, b1): + return op([a0, a1], [b0, b1]) + + inputs = ( + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + torch.rand(10, 10, device="cuda:0"), + torch.rand(20, 20, device="cuda:0"), + ) + + self.check_model_cuda(fn, inputs) + @unittest.skipIf(IS_FBCODE, "cpp compile not supported in fbcode") @bin_ops def test_cpu_cpp_fallback(self, op): diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 806d71b6605bd8..a54a9d71ba8f98 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -27,7 +27,7 @@ raise unittest.SkipTest("requires sympy/functorch/filelock") try: - import halide + import halide # @manual HAS_HALIDE = halide is not None except ImportError: @@ -37,7 +37,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library make_halide = config.patch( diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 6346751e7e9177..88f5530b57870c 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -10,6 +10,7 @@ import torch from torch import nn +from torch._dynamo.utils import counters from torch._inductor import config from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import override_lowering, run_and_get_code @@ -22,23 +23,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - TEST_WITH_ASAN, - TEST_WITH_ROCM, +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_cuda, + copy_tests, ) - - -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - -from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests +from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM importlib.import_module("functorch") @@ -94,6 +84,17 @@ def forward(self, x): return self.bn(self.conv(x)) +class ConvBNHardswish(torch.nn.Module): + def __init__(self, in_channels, out_channels, bias=False, **kwargs): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) + self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001, dtype=torch.float) + self.hardswish = nn.Hardswish(inplace=True) + + def forward(self, x): + return self.hardswish(self.bn(self.conv(x))) + + class ConvFunctionalBN(torch.nn.Module): def __init__( self, @@ -452,6 +453,54 @@ def test_folded_conv_bn(self): x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype) + torch._dynamo.reset() + counters.clear() + + @torch.compile() + def foo(mod, x): + return mod(x) + + # TODO - bias is separate kernel right now, we should only unfuse it + # from conv if it can be fused + + with torch.no_grad(): + out_eager = mod(x) + out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) + + # we unfuse the conv bias, but it should only have one constant in the kernel + if self.device == "cuda": + FileCheck().check_not(".run(").check("conv").check(".run(").check_same( + "frozen_param" + ).check_not("frozen_param").check_next("return").run(code[0]) + + self.assertEqual( + out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2 + ) + self.assertEqual(counters["inductor"]["binary_folding"], 4) + + @torch._inductor.config.patch(layout_optimization=False) + def test_folded_conv_bn_hardswish(self): + for use_bias, dtype in itertools.product( + [True, False], [torch.float16, torch.bfloat16, torch.float32] + ): + if self.device == "cpu" and dtype == torch.float16: + continue + + if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + continue + + mod = ( + ConvBNHardswish(3, 32, bias=use_bias, kernel_size=3, stride=2) + .eval() + .to(self.device) + .to(dtype) + ) + + x = torch.rand(3, 3, 32, 32).to(self.device).to(dtype) + + torch._dynamo.reset() + counters.clear() + @torch.compile() def foo(mod, x): return mod(x) @@ -472,6 +521,7 @@ def foo(mod, x): self.assertEqual( out_optimized_for_infernece, out_eager, atol=1e-2, rtol=1e-2 ) + self.assertEqual(counters["inductor"]["binary_folding"], 4) @torch._inductor.config.patch(layout_optimization=False) def test_folded_conv_bn_with_module_sharing(self): diff --git a/test/inductor/test_inplacing_pass.py b/test/inductor/test_inplacing_pass.py index c3aa3996053dce..280bcb25c37d9b 100644 --- a/test/inductor/test_inplacing_pass.py +++ b/test/inductor/test_inplacing_pass.py @@ -3,14 +3,23 @@ from typing import List import torch +import torch._inductor.config as inductor_config from functorch import make_fx from torch import Tensor from torch._dynamo.utils import counters -from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch.testing._internal.common_utils import IS_LINUX -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + IS_LINUX, + parametrize, + subtest, +) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string @@ -18,16 +27,16 @@ const = torch.tensor(0.0) -device = "cuda" +device = GPU_TYPE def num_reinplacing_failures(): return counters["inductor"]["possibly_missed_reinplacing_opportunities"] -@torch.library.custom_op("_reinplacing::sin", mutates_args={"out"}) -def sin(x: torch.Tensor, out: torch.Tensor) -> None: - out.copy_(x.sin()) +@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) +def sin(x: torch.Tensor, result: torch.Tensor) -> None: + result.copy_(x.sin()) @torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}) @@ -36,6 +45,40 @@ def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> No out_cos.copy_(x.cos()) +if HAS_GPU: + import triton # @manual + import triton.language as tl # @manual + + @triton.jit + def sin_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.sin(x) + tl.store(out_ptr + offsets, output, mask=mask) + + def sin_triton(x, out): + n_elements = x.numel() + sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + +else: + + def sin_triton(x, out): + return + + +@torch.library.custom_op("test_view::boo", mutates_args={"x"}) +def boo(x: torch.Tensor) -> None: + x.sin_() + + class TestReinplacingPassCorrectness(InductorTestCase): def setUp(self): counters.clear() @@ -90,12 +133,12 @@ def f(x, y): self._test(f) - def test_counters(self): + def test_counters_functionalize_old(self): counters.clear() def f(x): out = torch.empty_like(x) - _, new_out = auto_functionalized(sin._opoverload, x=x, out=out) + _, new_out = auto_functionalized(sin._opoverload, x=x, result=out) y = out * new_out return new_out, y @@ -109,21 +152,176 @@ def f(x): # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE self.assertEqual(num_reinplacing_failures(), 1) + def test_counters_functionalize_v2(self): + counters.clear() + + def f(x): + out = torch.empty_like(x) + _, new_out = auto_functionalized_v2( + sin._opoverload, + x=x, + _result_base_index=0, + _result_size=(3,), + _result_stride=(1,), + _result_storage_offset=0, + _all_bases=[out], + ) + y = out * new_out + return new_out, y + + x = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x) + reinplace_inplaceable_ops_core(gm.graph) + + # We shouldn't have been able to reinplace `out` because it was used after + # auto_functionalized. Note that this usually doesn't happen in practice; + # we're artificially creating this example to test the counter. + # IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE + self.assertEqual(num_reinplacing_failures(), 1) + + def get_not_inplaced_count(self, graph): + counter = 0 + auto_functionalized_found = False + for node in graph.nodes: + if (node.target == torch.ops.higher_order.auto_functionalized) or ( + node.target == torch.ops.higher_order.auto_functionalized_v2 + ): + auto_functionalized_found = True + counter += len(node.meta["only_clone_these_tensors"]) + assert auto_functionalized_found + return counter + + def test_view_inplaced_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return () + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 0) + + # introduce a view another_view that is used `after` the copy + def test_view_inplaced2_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 0) + + # introduce a view another_view that is used `before` the copy + def test_views_not_inplaced_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + use_another_view = another_view * 10 + copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_1) + return use_another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + + # a view over input without copy node, inplace not allowed + def test_views_not_inplaced2_functionalize_v2(self): + def f(arg0_1): + select = torch.ops.aten.select.int(arg0_1, 0, 0) + another_view = arg0_1[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(3,), + _x_stride=(1,), + _x_storage_offset=0, + _all_bases=[arg0_1], + ) + getitem_1 = auto_functionalized[1] + return + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + + # no copy nodes, view over local, with a use for another view + def test_views_not_inplaced3_functionalize_v2(self): + def f(arg0_1): + a = torch.ones(10) + another_view = a[2] + auto_functionalized = auto_functionalized_v2( + torch.ops.test_view.boo.default, + _x_base_index=0, + _x_size=(), + _x_stride=(), + _x_storage_offset=0, + _all_bases=[a], + ) + getitem_1 = auto_functionalized[1] + return another_view + + x1 = torch.randn(3, device=device) + gm = make_fx(f, tracing_mode="fake")(x1) + reinplace_inplaceable_ops_core(gm.graph) + + self.assertEqual(self.get_not_inplaced_count(gm.graph), 1) + def test_multi_output_intermediate(self): for requires_grad in [False, True]: - counters.clear() - - def f(x): - out1 = torch.empty_like(x) - out2 = torch.empty_like(x) - sin_cos(x, out1, out2) - return out1, out2, x**2 - - x = torch.randn(3, device=device, requires_grad=requires_grad) - res1, res2, _ = torch.compile(f)(x) - self.assertEqual(res1, x.sin()) - self.assertEqual(res2, x.cos()) - self.assertEqual(num_reinplacing_failures(), 0) + for enable_v2 in [False, True]: + with inductor_config.patch( + {"enable_auto_functionalized_v2": enable_v2} + ): + counters.clear() + + def f(x): + out1 = torch.empty_like(x) + out2 = torch.empty_like(x) + sin_cos(x, out1, out2) + return out1, out2, x**2 + + x = torch.randn(3, device=device, requires_grad=requires_grad) + res1, res2, _ = torch.compile(f)(x) + self.assertEqual(res1, x.sin()) + self.assertEqual(res2, x.cos()) + self.assertEqual(num_reinplacing_failures(), 0) def test_multiple_mutations(self): counters.clear() @@ -156,33 +354,102 @@ def f(x): self.assertEqual(result, x.sin().sin().sin()) self.assertEqual(num_reinplacing_failures(), 0) - def test_lists(self): - @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) - def mutate_op(y: List[Tensor]) -> None: - y[0].add_(2) - y[1].add_(3) - - @torch.compile(fullgraph=True, dynamic=False, backend="inductor") - def f(b): - mutate_op([b[0], b[1]]) + def test_lists_functionalize_v2(self): + with inductor_config.patch({"enable_auto_functionalized_v2": True}): + + @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) + def mutate_op(y: List[Tensor]) -> None: + y[0].add_(2) + y[1].add_(3) + + @torch.compile(fullgraph=True, dynamic=False, backend="inductor") + def f(b): + mutate_op([b[0], b[1]]) + + x1 = torch.tensor([0.3, 0.4], device=device) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(x1) + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # We can inplace the base y. no clones emitted. + self.assertEqual(num_reinplacing_failures(), 0) + self.assertEqual(post_grad_graphs.count("aten.clone"), 0) + + def test_lists_old_functionalize(self): + with inductor_config.patch({"enable_auto_functionalized_v2": False}): + + @torch.library.custom_op("mylib::mutate_op", mutates_args={"y"}) + def mutate_op(y: List[Tensor]) -> None: + y[0].add_(2) + y[1].add_(3) + + @torch.compile(fullgraph=True, dynamic=False, backend="inductor") + def f(b): + mutate_op([b[0], b[1]]) + + x1 = torch.tensor([0.3, 0.4], device=device) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + torch.compile(f, backend="inductor", fullgraph=True)(x1) + post_grad_graphs = "\n".join( + log_stream.getvalue().strip().split("\n")[3:] + ).strip() + + # Can't reinplace on views yet (1 for the "entire list" failing to reinplace) + self.assertEqual(num_reinplacing_failures(), 1) + + # Both list inputs failed to reinplace. So we should have emitted clones for them. + self.assertEqual(post_grad_graphs.count("aten.clone"), 2) + + @parametrize( + "factory_op", + [ + subtest(torch.ones_like, name="ones_like"), + subtest(torch.empty_like, name="empty_like"), + ], + ) + @parametrize( + "sin_op", + [ + subtest(sin, name="sin_op"), + subtest(sin_triton, name="sin_triton"), + ], + ) + def test_partitioner_recomputes_factory(self, factory_op, sin_op): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin_op(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin_op(saved, out) + return out + + @torch.compile(backend="inductor") + def f(x): + return MySin.apply(x) - x1 = torch.tensor([0.3, 0.4], device=device) - log_stream, ctx = logs_to_string( - "torch._inductor.compile_fx", "post_grad_graphs" - ) - with ctx(): - torch.compile(f, backend="inductor", fullgraph=True)(x1) - post_grad_graphs = "\n".join( - log_stream.getvalue().strip().split("\n")[3:] - ).strip() + x = torch.randn(3, requires_grad=True, device=device) + y = f(x) + self.assertEqual(num_reinplacing_failures(), 0) - # Can't reinplace on views yet (1 for the "entire list" failing to reinplace) - self.assertEqual(num_reinplacing_failures(), 1) - # Both list inputs failed to reinplace. So we should have emitted clones for them. - self.assertEqual(post_grad_graphs.count("aten.clone"), 2) +instantiate_parametrized_tests(TestReinplacingPassCorrectness) if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 0e90b4a4c708d7..f0d931ed41994b 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -1,29 +1,203 @@ # Owner(s): ["module: inductor"] +import contextlib +import unittest + +import numpy as np + import torch +from torch import nn from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same -from torch._inductor import config as inductor_config, metrics +from torch._inductor import config as inductor_config, ir, metrics +from torch._inductor.codegen.triton import TritonScheduling +from torch._inductor.graph import GraphLowering +from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.test_operators import realize +from torch._inductor.utils import sympy_index_symbol +from torch._inductor.virtualized import ops, V +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.utils._pytree import tree_map +from torch.utils._sympy.functions import ModularIndexing if HAS_CUDA: torch.set_default_device("cuda") +class MockScheduler: + available_buffer_names = () + + @staticmethod + def get_backend(cls, *args): + return TritonScheduling(cls) + + +@inductor_config.patch(loop_ordering_after_fusion=True) +class ImplDetailTest(TestCase): + _exit_stack = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + + gm = torch.fx.symbolic_trace(lambda: 0) + graph = GraphLowering(gm) + graph.scheduler = MockScheduler + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context(V.set_graph_handler(graph)) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._exit_stack.close() + + @staticmethod + def _get_snode_body_sym_prefix(snode): + body = snode._body + prefix = "" + + for var in body.var_ranges: + prefix = str(var)[0] + break + + assert prefix + return prefix + + @staticmethod + def _create_computed_buffer_ax2(sizes=(32, 64), strides=None): + """ + Create a ComputedBuffer for 'a x 2' + """ + if strides is None: + strides = ir.FlexibleLayout.contiguous_strides(sizes) + + box_a = ir.TensorBox.create( + ir.Buffer( + "a", ir.FixedLayout(torch.device("cuda"), torch.float32, sizes, strides) + ) + ) + box_a_loader = box_a.make_loader() + + def inner_fn(index): + return box_a_loader(index) * 2 + + buf = ir.Pointwise.create( + device=box_a.get_device(), + dtype=box_a.get_dtype(), + inner_fn=inner_fn, + ranges=box_a.get_size(), + ) + buf.realize() + computed_buf = buf.data.data + computed_buf.decide_layout() + return computed_buf + + def test_reorder_twice(self): + """ + This may happen in practice if we pick a order when fusing A and B. + Then we pick another order for AB when we fusion C into it. + + E.g. happens for BertForMaskedLM. + """ + + buf = self._create_computed_buffer_ax2() + snode = SchedulerNode(V.graph.scheduler, buf) + snode.apply_new_loop_order([1, 0]) + prefix1 = self._get_snode_body_sym_prefix(snode) + self.assertTrue(prefix1 == "z") + snode.apply_new_loop_order([1, 0]) + prefix2 = self._get_snode_body_sym_prefix(snode) + self.assertTrue(prefix2 == "z") + + def test_reorder_and_merge_loops(self): + sizes = (1024, 2048) + strides = (1, 1024) + buf = self._create_computed_buffer_ax2(sizes, strides) + old_sizes, old_body = buf.simplify_and_reorder() + + # Make sure loop reordering happens here + self.assertTrue(tuple(old_sizes[0]) == tuple(reversed(sizes)), f"{old_sizes=}") + new_body = old_body.merge_loops() + new_sizes = new_body.sizes + self.assertTrue(tuple(new_sizes[0]) == (np.prod(sizes),), f"{new_sizes=}") + + def test_reorder_modular_indexing(self): + """ + There was a bug that we wrongly map i0 to the dimension with size 49 + when reordering the loop and cause ModularIndexing get optimized away + as an no-op. + """ + + def _create_computed_buffer(): + def inner_fn(index): + i0, i1, i2, i3 = index + return ops.load( + "primal", i3 + 49 * i2 + 2401 * ModularIndexing(i0, 1, 64) + ) + + buf = ir.Pointwise.create( + device=torch.device("cuda"), + dtype=torch.float32, + inner_fn=inner_fn, + ranges=[128, 4, 49, 49], + ) + buf.realize() + cbuf = buf.data.data + cbuf.decide_layout() + return cbuf + + buf = _create_computed_buffer() + _, body = buf.simplify_and_reorder() + new_body = body.reorder_iter_loops([1, 2, 3, 0]) + + z0, z1, z2, z3 = (sympy_index_symbol(f"z{i}") for i in range(4)) + self.assertEqual(body.var_ranges, {z0: 128, z1: 4, z2: 49, z3: 49}) + self.assertEqual( + body.indexing_exprs["index0"], + z3 + 49 * z2 + 2401 * ModularIndexing(z0, 1, 64), + ) + self.assertEqual(new_body.var_ranges, {z0: 4, z1: 49, z2: 49, z3: 128}) + self.assertEqual( + new_body.indexing_exprs["index0"], + z2 + 49 * z1 + 2401 * ModularIndexing(z3, 1, 64), + ) + + @inductor_config.patch( { "benchmark_kernel": True, + "loop_ordering_after_fusion": True, "triton.unique_kernel_names": True, } ) class LoopOrderingTest(TestCase): - def do_acc_test(self, f, *args): + def do_acc_test(self, f, *args, cast_fp8=True): expect = f(*args) actual = torch.compile(f)(*args) + + if cast_fp8: + + def _cast(x): + if isinstance(x, torch.Tensor) and x.dtype in ( + torch.float8_e5m2, + torch.float8_e4m3fn, + ): + return x.to(torch.float32) + return x + + # Wordaround the issue that call allclose on fp8 tensor triggers error + # RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn' + expect = tree_map(_cast, expect) + actual = tree_map(_cast, actual) self.assertTrue(same(expect, actual, tol=1e-3)) + def setUp(self): + super().setUp() + metrics.reset() + def test_for_reordering_reindex(self): """ ComputedBuffer.iter_reoredering_reindex can cause some fusion @@ -54,6 +228,171 @@ def f(x, y): expected_num_bytes *= x.itemsize self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) + def test_apbt_realize(self): + M = 1024 + N = 2048 + + def f(x, y): + """ + There will be 2 kernels being generated without loop ordering after fusion: + https://gist.github.com/shunting314/44df83f71de2c110232c50ac6638ed69 + """ + x = realize(x * 2) + y = realize(y * 3) + return x + y + + x = torch.randn(M, N) + y = torch.randn(N, M).t() + + self.do_acc_test(f, x, y) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_sum_and_t(self): + N = 1024 + + def f(x): + return x.sum(dim=-1), x.t().contiguous() + + x = torch.randn(N, N * 2) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_pw_outer_red(self): + def f(x): + x = realize(x + 1) + return x.sum(dim=[0, 1]) + + # make the first 2 dimension small so we don't split the reduction + x = torch.randn(2, 4, 512) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_pw_outer_red_2(self): + """ + The pointwise kernel is a fused kernel + """ + + def f(x): + x = realize(x + 1) + x = realize(x - 2) + x = realize(x * 3) + return x.sum(dim=[0, 1]) + + # make the first 2 dimension small so we don't split the reduction + x = torch.randn(2, 4, 512) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + @inductor_config.patch(split_reductions=False) + def test_different_reduction_order(self): + """ + We should not reorder loops in this case. Since reordering loops does + not help! + """ + + def f(x): + return x.sum(dim=0), x.sum(dim=1) + + x = torch.randn(1024, 2048) + self.do_acc_test(f, x) + self.assertEqual(2, metrics.generated_kernel_count) + self.assertEqual(0, metrics.num_loop_reordering) + + def test_keep_fake_dep(self): + """ + In this model, there are fake dependencies (StarDep) between Scatter + and a following mutation kernel that computes the gradients of + the embedding tables. + + When we do loop reordering for the mutation kernel, we re-analyze + the node's dependencies. But the analysis result does not contains + those fake dependencies. Have to add them back manually. + """ + V = 2048 + hidden_size = 64 + max_seqlen = 512 + batch_size = 8 + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.word_embeddings = nn.Embedding(V, hidden_size) + self.position_embeddings = nn.Embedding(max_seqlen, hidden_size) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, input_ids, labels, position_ids): + emb = self.word_embeddings(input_ids) + self.position_embeddings( + position_ids + ) + return self.layer_norm(emb) + + m = Model() + + @torch.compile + def f(*args): + m(*args).sum().backward() + + input_ids = torch.randint(0, V, (batch_size, max_seqlen)) + labels = torch.randint(0, V, (batch_size, max_seqlen)) + position_ids = torch.arange(max_seqlen)[None, :] + # Make sure this line does not raise exceptions. If we miss + # fake dependencies after loop reordering, we may get exception that + # some buffer is used before being defined. + f(input_ids, labels, position_ids) + + def test_different_broadcast_shapes(self): + def f(x, y, c): + return x + c, y + c + + x = torch.randn(4, 256, 1024) + y = torch.randn(2, 512, 1024) + c = torch.randn(1024) + self.do_acc_test(f, x, y, c) + + # The two kernels are not fused due to c is broadcasted + self.assertEqual(2, metrics.generated_kernel_count) + + def test_view(self): + """ + Passing this test relies that we compare normalized MemoryDep. + Normlaization here means merging contiguous loops. + + To make loop reordering work, we don't merge loops when creating + SchedulerNode. Thus we need explicitly normalize MemoryDep when + we check if two MemeoryDep matches. + """ + + def f(x): + y = x.sin() + x = realize(x.view(10, 10)) + return x, y + + x = torch.randn(100) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+") + def test_fp8_cast_and_t(self): + """ + This test repros the not able to fuses issue in + https://github.com/pytorch/pytorch/issues/130015 + for fp8 cast and transpose + """ + + def f(x, scale): + x = x * scale + x = x.clamp(-1 * E4M3_MAX_POS, E4M3_MAX_POS) + x = x.to(torch.float8_e4m3fn) + x_t = x.t().contiguous().t() + return x, x_t + + x = torch.randn(4096, 4096, dtype=torch.bfloat16) + scale = torch.Tensor([10.0]).cuda() + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + + self.do_acc_test(f, x, scale) + self.assertEqual(1, metrics.generated_kernel_count) + if __name__ == "__main__": if HAS_CUDA: diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 61e21189a9b51a..8f5eacc0c14ae6 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -35,9 +35,9 @@ try: - from .mock_cache import PatchCaches + from .mock_cache import global_stats, PatchCaches except ImportError: - from mock_cache import PatchCaches # @manual + from mock_cache import global_stats, PatchCaches # @manual torch.set_float32_matmul_precision("high") @@ -258,7 +258,7 @@ def no_lookup( op: str, inputs: str, benchmark: Callable[[Any], Dict[ChoiceCaller, float]], - ) -> Dict[ChoiceCaller, float]: + ) -> Optional[Dict[ChoiceCaller, float]]: if benchmark is not None: return benchmark(choices) @@ -377,6 +377,22 @@ def fn(a, b, c): fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm + @fresh_inductor_cache() + @config.patch(search_autotune_cache=True) + def test_search_autotune_cache(self): + def fn(a, b, c): + a = (a @ b) @ c + a, b, c = (t.to(torch.float16) for t in [a, b, c]) + return (a @ b) @ c + + fn_c = torch.compile()(fn) + inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)] + from torch._dynamo.utils import counters + + self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2) + self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) + @skipIfRocm @fresh_inductor_cache() @config.patch(max_autotune=True, max_fusion_size=2) @@ -723,9 +739,6 @@ def tearDown(self): def test_max_autotune_remote_caching(self, dynamic: bool): from unittest.mock import patch - if not config.is_fbcode(): - self.skipTest("Redis for autotune is currently broken") - def mm(a, b): a = torch.sin(a) return a @ b @@ -756,22 +769,20 @@ def f(x, y): torch.compile(mm, dynamic=dynamic)(a, b) reset() - PatchCaches.update() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 3) + self.assertEqual(global_stats.autotune.num_get_miss, 1) + self.assertEqual(global_stats.autotune.num_put, 1) - PatchCaches.reset() + global_stats.reset() for _ in range(4): with fresh_inductor_cache(): torch.compile(f, dynamic=dynamic)(x, y) reset() - PatchCaches.update() - PatchCaches.report() - self.assertEqual(PatchCaches.num_get_hit, 3) - self.assertEqual(PatchCaches.num_get_miss, 1) - self.assertEqual(PatchCaches.num_put, 1) + global_stats.report() + self.assertEqual(global_stats.autotune.num_get_hit, 3) + self.assertEqual(global_stats.autotune.num_get_miss, 1) + self.assertEqual(global_stats.autotune.num_put, 1) class TestBenchmarkRequest(BenchmarkRequest): @@ -796,6 +807,7 @@ def benchmark( if not self.multi_device: assert visible_devices == self.parent_visible_devices else: + assert self.parent_visible_devices is not None valid_devices = self.parent_visible_devices.split(",") assert visible_devices in valid_devices diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index df125324e89723..d3e07670492123 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -84,7 +84,9 @@ def test_abi_compatible(self): try: from .test_aot_inductor import AOTIRunnerUtil except ImportError: - from test_aot_inductor import AOTIRunnerUtil + from test_aot_inductor import ( + AOTIRunnerUtil, # @manual=fbcode//caffe2/test/inductor:test_aot_inductor-library + ) f, args = self._generate(device="cuda") dim0_x = Dim("dim0_x", min=1, max=2048) diff --git a/test/inductor/test_minifier_isolate.py b/test/inductor/test_minifier_isolate.py index 34b9c9f3383666..61cf6e39611336 100644 --- a/test/inductor/test_minifier_isolate.py +++ b/test/inductor/test_minifier_isolate.py @@ -7,12 +7,12 @@ IS_JETSON, IS_MACOS, skipIfRocm, + skipIfWindows, + skipIfXpu, TEST_WITH_ASAN, ) -from torch.testing._internal.inductor_utils import HAS_CUDA - - -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.triton_utils import requires_gpu # These minifier tests are slow, because they must be run in separate @@ -33,14 +33,18 @@ def inner(x): @unittest.skipIf(IS_JETSON, "Fails on Jetson") @inductor_config.patch("cpp.inject_relu_bug_TESTING_ONLY", "runtime_error") + @skipIfWindows( + msg="Build Failed: fatal error C1083: Cannot open include file: 'Python.h': No such file or directory" + ) def test_after_aot_cpu_runtime_error(self): self._test_after_aot_runtime_error("cpu", "") @skipIfRocm - @requires_cuda + @skipIfXpu + @requires_gpu @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error") - def test_after_aot_cuda_runtime_error(self): - self._test_after_aot_runtime_error("cuda", "device-side assert") + def test_after_aot_gpu_runtime_error(self): + self._test_after_aot_runtime_error(GPU_TYPE, "device-side assert") if __name__ == "__main__": diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 73fa12fe329778..259975996bd90e 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -733,7 +733,9 @@ def __init__( super().__init__() self.conv = torch.nn.Conv2d(3, 128, kernel_size=3, stride=1) self.unary_fn = copy.deepcopy(unary_op) - self.conv2 = torch.nn.Conv2d(128, 128, kernel_size=3, stride=1) + self.conv2 = torch.nn.Conv2d( + 128, 128, kernel_size=3, stride=1, bias=False + ) self.unary_fn2 = copy.deepcopy(unary_op) def forward(self, x): @@ -889,8 +891,8 @@ def __init__( self.conv2 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) self.add_fn = add_fn self.relu = torch.nn.ReLU() - self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) - self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1) + self.conv3 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.conv4 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) self.add_fn2 = add_fn self.relu2 = torch.nn.ReLU() self.use_relu = use_relu @@ -1800,12 +1802,17 @@ def matcher_check_fn(): mod, (v,), [ + "torch.ops.onednn.qlinear_pointwise.tensor", + "torch.ops.onednn.qlinear_pointwise.binary", + ] + if config.abi_compatible + else [ "op_onednn_qlinear_pointwise_tensor.call", "op_onednn_qlinear_pointwise_binary_tensor.call", ], [], check_quantization=True, - num_include_ops=[2, 2], + num_include_ops=[4, 4] if config.abi_compatible else [2, 2], ) else: # For python wrapper @@ -2578,19 +2585,36 @@ def forward(self, x): @skipIfNoDynamoSupport def test_woq_int8(self): class M(torch.nn.Module): + def __init__(self, is_permute): + super().__init__() + self.is_permute = is_permute + def forward(self, x, weight, scales): - return torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales + if self.is_permute: + weight = weight.t() + m = torch.mm( + x.reshape(-1, x.shape[-1]), + weight.to(x.dtype), + ) + y = m * scales.to(m.dtype) + y = y.reshape(*x.shape[:-1], y.shape[-1]) + return y + else: + return ( + torch.nn.functional.linear(x, weight.to(dtype=x.dtype)) * scales + ) - mod = M().eval() x_shape = (1, 1, 256) - w_shape = (12, 256) s_shape = 12 x_strides = [ (256, 256, 1), # linear dispatching to mm (256, 32, 1), # linear dispatching to bmm ] - for x_stride in x_strides: + is_permutes = [False, True] + for x_stride, is_permute in itertools.product(x_strides, is_permutes): + mod = M(is_permute=is_permute).eval() x = torch.randn(x_shape, dtype=torch.bfloat16).as_strided(x_shape, x_stride) + w_shape = (12, 256) w = torch.randint(-128, 127, w_shape, dtype=torch.int8) s = torch.randn(s_shape, dtype=torch.bfloat16) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c8170c92327166..9ae3dd3a125df8 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -3,6 +3,7 @@ import functools import os import unittest +from typing import Tuple import torch from torch import nn, Tensor @@ -12,8 +13,13 @@ from torch._inductor import config, ir, metrics from torch._inductor.fx_passes import pad_mm as pad_mm_pass from torch._inductor.runtime.benchmarking import benchmarker -from torch._inductor.utils import run_and_get_code -from torch.testing._internal.common_utils import requires_cuda, serialTest +from torch._inductor.utils import ceildiv, run_and_get_code +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + requires_cuda, + serialTest, +) from torch.testing._internal.inductor_utils import HAS_CUDA @@ -362,6 +368,7 @@ def test_longformer_small_bs(self): self.test_longformer(bs=2) +@instantiate_parametrized_tests class PaddingTest(TestCaseBase): @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") def test_mm_padding_perf(self): @@ -654,6 +661,53 @@ def test_pad_channels_last(self): out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32) self.assertTrue(in_strides == out_strides) + @parametrize("alignment_bytes", (32, 128)) + @parametrize("shape", [(21, 19), (3, 5, 71)]) + @parametrize("dtype", (torch.float16, torch.float32)) + def test_pad_outputs( + self, dtype: torch.dtype, shape: Tuple[int], alignment_bytes: int + ): + """ + Tests padding output tensors to a specific alignment. + This is enabled by a config flag. + """ + func = torch.add + inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2)) + + # Compile and run + with config.patch( + { + "comprehensive_padding": True, + "padding_alignment_bytes": alignment_bytes, + "padding_stride_threshold": 0, + "pad_outputs": True, + } + ): + compiled_func = torch.compile(func) + compiled_out = compiled_func(*inputs) + + # Check numerics + eager_out = func(*inputs) + self.check_close(eager_out, compiled_out) + + # Compute the expected padding + element_size = torch.tensor([], dtype=dtype).element_size() + self.assertGreater(alignment_bytes, element_size) + self.assertEqual(alignment_bytes % element_size, 0) + alignment_elements = alignment_bytes // element_size + contiguous_stride = inputs[0].stride() + expected_stride = [1] + for dim in reversed(shape[1:]): + slice_size = dim * expected_stride[0] + new_stride = alignment_elements * ceildiv(slice_size, alignment_elements) + expected_stride.insert(0, new_stride) + expected_stride = tuple(expected_stride) + self.assertNotEqual(expected_stride, contiguous_stride) + + # Check strides + self.assertFalse(compiled_out.is_contiguous()) + self.assertEqual(compiled_out.stride(), expected_stride) + if __name__ == "__main__": if HAS_CUDA: diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 3ad1b5a6ab1fef..e63e8fcfb4e96d 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -66,6 +66,7 @@ def common( additional_check(codes) counters.clear() + @inductor_config.patch(max_autotune_gemm=True) def test_mm_plus_mm(self): def fn(a, b, c, d): return torch.add(torch.mm(a, b), torch.mm(c, d)) @@ -748,6 +749,7 @@ def fn(a, b): torch.randn(2, 8, device="cuda"), torch.randn(2, 16, device="cuda"), ] + torch._dynamo.reset() counters.clear() expected = fn(*args) actual = torch.compile(fn)(*args) @@ -1446,7 +1448,7 @@ def check(type, func_name, args, kwargs, expect=True): (t, [64, 128, 8, 8]), {"dim": 1, "out": [t, t, t, t]}, ) - check("call_function", torch.ops.fsdp.set_, (t, t), {}) + check("call_function", torch.ops.fsdp.copy_, (t, t), {}) check( "call_function", torch.ops.aten.__rshift__.Scalar, (t, 2), {}, expect=False ) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index a7530f71523c54..7de94642f31dde 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -32,6 +32,9 @@ if HAS_CUDA: + import triton # @manual + import triton.language as tl # @manual + from torch.testing._internal.triton_utils import add_kernel aten = torch.ops.aten @@ -858,6 +861,99 @@ def f(a, b): inp = (T(10, 10), TI(2, mx=5)) self.assertExpectedInline(count_numel(f, *inp), """42""") + @requires_cuda + def test_inplace_triton_kernel_training(self): + @triton.jit + def sin_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.sin(x) + tl.store(out_ptr + offsets, output, mask=mask) + + def sin_triton(x, out): + n_elements = x.numel() + sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + + factory_op = torch.empty_like + + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin_triton(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin_triton(saved, out) + return out + + def f(x): + return MySin.apply(x) + + x = T(3, grad=True) + self.assertExpectedInline(count_numel_train(f, x), """9""") + + @requires_cuda + def test_inplace_custom_op_training_two_mutated_inputs(self): + @torch.library.custom_op( + "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"} + ) + def sin_cos( + x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor + ) -> None: + out_sin.copy_(x.sin()) + out_cos.copy_(x.cos()) + + def f(x): + out0 = torch.empty_like(x) + out1 = torch.empty_like(x) + sin_cos(x, out0, out1) + return x.clone(), out0, out1 + + x = T(3, grad=True) + self.assertExpectedInline(count_numel(f, x), """21""") + + @requires_cuda + def test_inplace_custom_op_training(self): + @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"}) + def sin(x: torch.Tensor, result: torch.Tensor) -> None: + result.copy_(x.sin()) + + factory_op = torch.empty_like + + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = factory_op(x) + sin(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = factory_op(grad) + sin(saved, out) + return out + + def f(x): + return MySin.apply(x) + + x = T(3, grad=True) + self.assertExpectedInline(count_numel_train(f, x), """9""") + @requires_cuda def test_inplace_custom_op(self): with torch.library._scoped_library("mylib", "FRAGMENT") as m: diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index e4af44c761d486..016ee768f890c6 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -172,7 +172,7 @@ def fn(x, y): @unittest.skipIf(not HAS_TRITON, "requires cuda & triton") def test_inductor_profiling_triton_hooks(self): - from triton.compiler import CompiledKernel + from triton.compiler import CompiledKernel # @manual hooks_called = {"enter": False, "exit": False} diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 2c32b01be2d60f..3e775ef2de8e4b 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -27,7 +27,12 @@ def patch(f): class TestSplitCatFxPasses(TestCase): - @patch + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "normalization_pass": {}, + }, + post_grad_fusion_options={}, + ) def test_split_normalization(self): def arg_only(x): return [torch.relu(s) for s in torch.split(x, 2, 1)] @@ -76,27 +81,31 @@ def unequal_split_cm(x): def cm_with_list(x): return [torch.relu(s) for s in x.split([16, 16], dim=-1)] + def normalize_reshape_with_dynamic_shape(x): + return x.reshape(4, 16) + args = [ torch.randn(2, 32), ] - for fn, expected_split_norm_count in [ - (arg_only, 1), - (arg_only_dim0, 1), - (kwarg1, 1), - (kwarg2, 1), - (kwarg3, 1), - (list_replace, 0), - (multi_split, 17), - (unequal_split, 1), - (arg_only_cm, 1), - (kwarg1_cm, 1), - (kwarg2_cm, 1), - (multi_split_cm, 17), - (unequal_split_cm, 1), - (cm_with_list, 1), + for fn, dynamic, expected_split_norm_count in [ + (arg_only, False, 1), + (arg_only_dim0, False, 1), + (kwarg1, False, 1), + (kwarg2, False, 1), + (kwarg3, False, 1), + (list_replace, False, 0), + (multi_split, False, 17), + (unequal_split, False, 1), + (arg_only_cm, False, 1), + (kwarg1_cm, False, 1), + (kwarg2_cm, False, 1), + (multi_split_cm, False, 17), + (unequal_split_cm, False, 1), + (cm_with_list, False, 1), + (normalize_reshape_with_dynamic_shape, True, 0), ]: expected = fn(*args) - actual = torch.compile(fn)(*args) + actual = torch.compile(fn, dynamic=dynamic)(*args) torch.testing.assert_close(actual, expected) self.assertEqual( diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bc22dc3852a121..f2240ff64922d4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -516,7 +516,13 @@ def reference_to_expect(actual_flat, correct_flat): correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads) all_none_grads = all(x is None for x in correct_grad) - if all_none_grads: + tensor_args = [ + x + for x in pytree.tree_flatten(example_inputs)[0] + if isinstance(x, torch.Tensor) + ] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all_none_grads and any_non_leaves: # See Note [Detaching inputs that never need gradients] # There are a handful of ops that can return None gradients, into of zero gradients. # If all inputs to an AOTAutograd graph are supposed to get None gradients, @@ -665,7 +671,7 @@ def _run_and_assert_no_indirect_indexing( ) if has_wrapping is not None: test_case.assertTrue( - ("where" in code or "?" in code) is has_wrapping, + ("where" in code or ") ? (" in code) is has_wrapping, msg=f"Wanted {has_wrapping=} but got\n{code}", ) test_case.assertTrue( @@ -1517,7 +1523,7 @@ def test( pass # no device asserts in halide elif self.device == "cpu": _, code = run_and_get_cpp_code(fn_opt, *inps) - self.assertTrue(("?" in code or "blendv" in code) is has_wrapping) + self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) # Assert that we always vectorize the kernel regardless of wrapping / checks self.assertTrue(("loadu" in code) is vectorize) @@ -1878,6 +1884,21 @@ def fn(a, b): b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device) self.common(fn, (a, b), rtol=1e-4, atol=1e-5, check_lowp=False) + @config.patch(max_autotune_pointwise=True) + def test_split_cumsum_index(self): + # Split scan uses a workspace that needs to be zeroed before use. + # data[index] does indirect indexing that should catch issues if the + # workspace is not zeroed. + def fn(lengths, data): + offsets = torch.cumsum(lengths, 0) + return data[offsets] + + lengths = torch.full((2**14,), 2**2, dtype=torch.int64, device=self.device) + lengths[-2] = 3 + lengths[-1] = 3 + data = make_tensor((2**16,), dtype=torch.float32, device=self.device) + self.common(fn, (lengths, data)) + def test_split_cumprod(self): def fn(a): return torch.cumprod(a, -1) @@ -7383,6 +7404,7 @@ def fn(a, b): b = torch.empty(0) self.common(fn, [a, b]) + @with_tf32_off def test_slice_scatter_reinplace(self): class M(nn.Module): def __init__(self, device): @@ -7969,7 +7991,6 @@ def g(x): out = [torch.empty_like(x) for _ in range(2)] y = g(x) - @expectedFailureXPU def test_functionalize_rng_wrappers(self): # Ideally, we would like to use torch.compile for these operators. But # currently the plan is to introduce these operators at the partitioner @@ -8135,6 +8156,17 @@ def fn(x): self.common(fn, [torch.zeros([20, 20])]) + @config.patch(check_stack_no_cycles_TESTING_ONLY=True) + def test_check_stack_no_cycles(self): + @torch.compile() + def fn(x): + return x * 3 + + r = fn(torch.randn(2, device=self.device, requires_grad=True)) + # Backward compilation isn't hooked into cprofile, it probably + # should... + # r.sum().backward() + def test_like_rands2(self): # rand_like with kwargs `device` of str type d = self.device @@ -10596,6 +10628,7 @@ def test_mutable_custom_op_fixed_layout2(self): lib.define( "bar(Tensor x, bool is_compiling) -> Tensor", + tags=torch.Tag.flexible_layout, ) bar_strides = [] @@ -10632,8 +10665,8 @@ def fn(x): with torch.no_grad(): self.common(fn, (inp,), check_lowp=False) - # Dynamic shapes invalidate this test case - if torch._dynamo.config.assume_static_by_default: + # Dynamic shapes and rocm invalidate this test case + if torch._dynamo.config.assume_static_by_default and not TEST_WITH_ROCM: # For this test to be valid, Inductor must have changed the conv # to be channels-last. If this assertion ever fails then we need # a new test case. @@ -10875,6 +10908,7 @@ def fn(x, y): fn(a, b) # Skipped on ROCm until https://github.com/ROCm/triton/issues/443 resolved + @slowTest def test_fuse_large_params(self): def pt2_optimizer_step(optimizer): @torch.compile() @@ -11040,7 +11074,6 @@ def fn(a, b): actual = torch.compile(fn)(a, b) self.assertEqual(ref, actual) - @skipIfWindows(msg="torch._dynamo.exc.BackendCompilerFailed") # TODO: FIX IT def test_randint_int64_mod(self): # This used to not compile due to a wrong return type of randint64_cpu # See https://github.com/pytorch/pytorch/issues/117435 @@ -11930,7 +11963,7 @@ def func(a, b): with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32): func_opt = torch._dynamo.optimize("inductor")(func) code = run_and_get_triton_code(func_opt, *inps) - fp32_cast_in_code = "float32" in code + fp32_cast_in_code = "to(tl.float32)" in code self.assertEqual(fp32_cast_in_code, upcast_to_fp32) @config.patch("triton.use_block_ptr", False) diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index e963eabe2571b0..729d368a1e5220 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -2,17 +2,11 @@ import importlib import os import sys -import unittest import torch from torch._inductor.compile_fx import compile_fx from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import ( - IS_CI, - IS_WINDOWS, - TEST_WITH_ASAN, - TEST_WITH_ROCM, -) +from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import ( _check_has_dynamic_shape, GPU_TYPE, @@ -21,27 +15,19 @@ ) -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor_codegen_dynamic_shapes yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - importlib.import_module("filelock") # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import ( +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library CommonTemplate, copy_tests, run_and_get_cpp_code, run_and_get_triton_code, TestFailure, ) -from inductor.test_torchinductor_dynamic_shapes import ( +from inductor.test_torchinductor_dynamic_shapes import ( # @manual make_dynamic_cls, test_failures as dynamic_shapes_test_failures, ) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 5f97c3c67490d9..5dee3e2956bae9 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -7,7 +7,7 @@ import sys import unittest from functools import partial -from typing import List +from typing import List, Tuple import torch import torch.library @@ -27,8 +27,7 @@ ) from torch.testing._internal.common_utils import ( IS_ARM64, - IS_CI, - IS_WINDOWS, + IS_FBCODE, parametrize, TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_ASAN, @@ -37,18 +36,10 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU -if IS_WINDOWS and IS_CI: - sys.stderr.write( - "Windows CI does not have necessary dependencies for test_torchinductor_dynamic_shapes yet\n" - ) - if __name__ == "__main__": - sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") - # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import ( +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library check_model, check_model_gpu, CommonTemplate, @@ -372,7 +363,7 @@ def f(x, r): f(torch.tensor([3], device=device), torch.randn(10, device=device)) - @unittest.expectedFailure + @unittest.skipUnless(IS_FBCODE, "") @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) @@ -570,6 +561,28 @@ def f(x): f(torch.tensor([3], device=device)) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._inductor.config.patch(implicit_fallbacks=True) + def test_multi_output_unbacked_custom_op(self, device): + @torch.library.custom_op("test::foo", mutates_args=()) + def foo(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(2, device=x.device), torch.empty(3, device=x.device) + + @foo.register_fake + def _(x: torch.Tensor) -> torch.Tensor: + ctx = torch.library.get_ctx() + u0 = ctx.new_dynamic_size() + return torch.empty(u0, device=x.device), torch.empty(3, device=x.device) + + @torch.compile(fullgraph=True) + def f(x): + a, b = torch.ops.test.foo(x) + return a.sum() + b.sum() + + f(torch.tensor([3], device=device)) + @torch._inductor.config.patch(disable_cpp_codegen=True) def test_floor(self): # `int(n * 0.2)` will be generated as `floor(0.2*s0)` of torch.SymInt type. diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index fdac0cc9cbea59..24c4392b4e533e 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -26,6 +26,7 @@ ops, skipCPUIf, skipCUDAIf, + skipXPUIf, ) from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( @@ -40,7 +41,7 @@ TEST_WITH_ASAN, TEST_WITH_ROCM, ) -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_XPU from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -49,7 +50,10 @@ try: from .test_torchinductor import check_model, check_model_gpu except ImportError: - from test_torchinductor import check_model, check_model_gpu + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model, + check_model_gpu, + ) except (unittest.SkipTest, ImportError) as e: sys.stderr.write(f"{type(e)}: {e}\n") if __name__ == "__main__": @@ -204,6 +208,8 @@ def format_op(op): inductor_skips["cuda"]["logcumsumexp"] = {f32} inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64} +inductor_skips["xpu"] = {} + inductor_expected_failures_single_sample = defaultdict(dict) inductor_expected_failures_single_sample["cpu"] = { @@ -253,6 +259,109 @@ def format_op(op): }, # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA } +inductor_expected_failures_single_sample["xpu"] = { + "_upsample_bilinear2d_aa": {f16, f32, f64}, + "cholesky": {f32, f64}, + "multinomial": {f16, f32, f64}, + ("normal", "in_place"): {f16, f32, f64}, + ("normal", "number_mean"): {f16, f32, f64}, + "normal": {f16, f32, f64}, + "sparse.sampled_addmm": {f32, f64}, + "tan": {f16}, + "torch.ops.aten._flash_attention_forward": {f16}, + "torch.ops.aten._efficient_attention_forward": {f16, f32}, + "to_sparse": {f16, f32, f64, b8, i32, i64}, + "linalg.eig": {f32, f64}, + "linalg.eigvals": {f32, f64}, + # Double and complex datatype matmul is not supported in oneDNN + "__rmatmul__": {f64}, + ("addmm", "decomposed"): {f64}, + "addr": {f64}, + "baddbmm": {f64}, + "bmm": {f64}, + "byte": {f16, f32}, + "cdist": {f64}, + "corrcoef": {f64}, + "cov": {f64}, + "einsum": {f64}, + "inner": {f64}, + "linalg.cholesky_ex": {f64}, + "linalg.cholesky": {f64}, + ("linalg.det", "singular"): {f64}, + "linalg.ldl_factor_ex": {f64}, + "linalg.ldl_factor": {f64}, + "linalg.ldl_solve": {f64}, + "linalg.matrix_power": {f64}, + "linalg.multi_dot": {f64}, + "matmul": {f64}, + "mm": {f64}, + "mv": {f64}, + "nn.functional.bilinear": {f64}, + "nn.functional.linear": {f64}, + "pca_lowrank": {f64}, + "svd_lowrank": {f64}, + "tensordot": {f64}, + "triangular_solve": {f64}, + "svd": {f64}, + "qr": {f64}, + "pinverse": {f64}, + "ormqr": {f64}, + ("norm", "nuc"): {f64}, + "lu": {f64}, + "lu_solve": {f64}, + "logdet": {f64}, + "linalg.tensorsolve": {f64}, + "linalg.tensorinv": {f64}, + "linalg.svdvals": {f64}, + "linalg.svd": {f64}, + "linalg.solve": {f64}, + "linalg.solve_triangular": {f64}, + "linalg.solve_ex": {f64}, + "linalg.slogdet": {f64}, + "linalg.qr": {f64}, + "linalg.pinv": {f64}, + ("linalg.pinv", "hermitian"): {f64}, + "linalg.norm": {f64}, + ("linalg.norm", "subgradients_at_zero"): {f64}, + "linalg.matrix_rank": {f64}, + ("linalg.matrix_rank", "hermitian"): {f64}, + "linalg.matrix_norm": {f64}, + "linalg.lu": {f64}, + "linalg.lu_solve": {f64}, + "linalg.lu_factor": {f64}, + "linalg.lu_factor_ex": {f64}, + "linalg.lstsq": {f64}, + ("linalg.lstsq", "grad_oriented"): {f64}, + "linalg.inv": {f64}, + "linalg.inv_ex": {f64}, + "linalg.householder_product": {f64}, + "linalg.eigvalsh": {f64}, + "linalg.eigh": {f64}, + "linalg.det": {f64}, + "linalg.cond": {f64}, + "geqrf": {f64}, + "cholesky_solve": {f64}, + "cholesky_inverse": {f64}, + # could not create a primitive + "addbmm": {f16, f32, f64}, + "addmm": {f16, f32, f64}, + "addmv": {f32, f64}, + # could not create a primitive descriptor for + # a deconvolution forward propagation primitive + "nn.functional.conv_transpose2d": {f32, f64}, + "nn.functional.conv_transpose3d": {f32, f64}, + # rrelu not supported on XPU now + "nn.functional.rrelu": {f16, f32, f64}, + "histc": {i32, i64}, + # not implemented for 'Half' + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool3d": {f16}, + # not implemented for 'Bool' + "nn.functional.unfold": {b8}, +} + # intentionally not handled intentionally_not_handled = { @@ -275,11 +384,13 @@ def format_op(op): } inductor_expected_failures_single_sample["cuda"].update(intentionally_not_handled) +inductor_expected_failures_single_sample["xpu"].update(intentionally_not_handled) inductor_gradient_expected_failures_single_sample = defaultdict(dict) inductor_gradient_expected_failures_single_sample["cuda"] = {} +inductor_gradient_expected_failures_single_sample["xpu"] = {} if not TEST_MKL: inductor_expected_failures_single_sample["cpu"].update({}) @@ -287,6 +398,7 @@ def format_op(op): inductor_should_fail_with_exception = defaultdict(dict) inductor_should_fail_with_exception["cpu"] = {} inductor_should_fail_with_exception["cuda"] = {} +inductor_should_fail_with_exception["xpu"] = {} def get_skips_and_xfails(from_dict, xfails=True): @@ -321,9 +433,10 @@ def wrapper_noop_set_seed(op, *args, **kwargs): wrapper_noop_set_seed ) +# key can be either op_name, or (op_name, dtype) +inductor_override_kwargs = defaultdict(dict) -# key can be either op_name, or (op_name, deivce_type), or (op_name, device_type, dtype) -inductor_override_kwargs = { +inductor_override_kwargs["cpu"] = { # the return value of empty is undefined "empty": {"assert_equal": False}, "empty_permuted": {"assert_equal": False}, @@ -332,92 +445,222 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "empty_strided": {"assert_equal": False}, "new_empty_strided": {"assert_equal": False}, "randn": {"assert_equal": False}, - ("cross", "cuda", f16): {"reference_in_float": True}, - ("linalg.cross", "cuda", f16): {"reference_in_float": True}, - ("addr", "cuda", f16): {"reference_in_float": True}, - ("baddbmm", "cuda", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy - ("angle", "cuda", f64): {"reference_in_float": True}, - ("asin", "cuda", f16): {"reference_in_float": True}, - ("atanh", "cuda", f16): {"reference_in_float": True}, - ("cauchy", "cuda"): {"reference_in_float": True}, - ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, - ("cumsum", "cuda", f16): {"reference_in_float": True}, - ("cumprod", "cuda"): {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, - ("logcumsumexp", "cuda"): {"grad_atol": 8e-4, "grad_rtol": 0.001}, - ("exponential", "cuda"): {"reference_in_float": True}, - ("geometric", "cuda"): {"reference_in_float": True}, - ("kron", "cuda", f16): {"reference_in_float": True}, - ("log_normal", "cuda"): {"reference_in_float": True}, - ("masked.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("nn.functional.batch_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.batch_norm.without_cudnn", "cuda", f16): { - "reference_in_float": True - }, - ("nn.functional.cosine_similarity", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.instance_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.local_response_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.normalize", "cuda", f16): {"atol": 1e-3, "rtol": 0.05}, - ("nn.functional.rms_norm", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.soft_margin_loss", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.softmin", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("nn.functional.softsign", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, - ("nn.functional.multilabel_soft_margin_loss", "cpu", f16): { + ("nn.functional.multilabel_soft_margin_loss", f16): { "atol": 3e-4, "rtol": 0.002, }, - ("outer", "cuda", f16): {"reference_in_float": True}, - ("round.decimals_3", "cuda", f16): {"reference_in_float": True}, - ("nn.functional.triplet_margin_loss", "cuda", f16): {"atol": 1e-4, "rtol": 0.02}, - ("nn.functional.triplet_margin_with_distance_loss", "cuda", f16): { - "atol": 1e-4, - "rtol": 0.02, - }, - ("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, - ("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02}, - ("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02}, - ("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02}, - ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, - ("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5}, - ("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4}, - ("special.polygamma.special_polygamma_n_0", "cpu", f32): { + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("polygamma.polygamma_n_0", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_1", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_2", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_3", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("polygamma.polygamma_n_4", f32): {"atol": 1e-3, "rtol": 1e-4}, + ("special.polygamma.special_polygamma_n_0", f32): { "atol": 1e-3, "rtol": 1e-4, }, - ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True}, - ("uniform", "cuda"): {"reference_in_float": True}, - ("_unsafe_masked_index_put_accumulate", "cuda", f16): {"atol": 1e-4, "rtol": 0.01}, - ("_unsafe_masked_index_put_accumulate", "cpu", f16): {"atol": 1e-4, "rtol": 0.01}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors - ("nn.functional.interpolate.bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, - ("nn.functional.upsample_bilinear", "cpu", u8): {"atol": 1, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cpu", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.interpolate.bilinear", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.upsample_bilinear", u8): {"atol": 1, "rtol": 0}, + ("nn.functional.interpolate.bicubic", u8): {"atol": 1, "rtol": 0}, + # High atol due to precision loss + ("nn.functional.interpolate.bicubic", f32): {"atol": 5e-3, "rtol": 0}, +} + +inductor_override_kwargs["cuda"] = { + # the return value of empty is undefined + "empty": {"assert_equal": False}, + "empty_permuted": {"assert_equal": False}, + "empty_like": {"assert_equal": False}, + "new_empty": {"assert_equal": False}, + "empty_strided": {"assert_equal": False}, + "new_empty_strided": {"assert_equal": False}, + "randn": {"assert_equal": False}, + ("cross", f16): {"reference_in_float": True}, + ("linalg.cross", f16): {"reference_in_float": True}, + ("addr", f16): {"reference_in_float": True}, + ("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy + ("angle", f64): {"reference_in_float": True}, + ("asin", f16): {"reference_in_float": True}, + ("atanh", f16): {"reference_in_float": True}, + "cauchy": {"reference_in_float": True}, + ("cummax", f16): {"atol": 5e-4, "rtol": 0.002}, + ("cumsum", f16): {"reference_in_float": True}, + "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, + "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, + "exponential": {"reference_in_float": True}, + "geometric": {"reference_in_float": True}, + ("kron", f16): {"reference_in_float": True}, + "log_normal": {"reference_in_float": True}, + ("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.batch_norm", f16): {"reference_in_float": True}, + ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, + ("nn.functional.cosine_similarity", f16): {"reference_in_float": True}, + ("nn.functional.instance_norm", f16): {"reference_in_float": True}, + ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, + ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, + ("nn.functional.rms_norm", f16): {"reference_in_float": True}, + ("nn.functional.soft_margin_loss", f16): {"reference_in_float": True}, + ("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.softsign", f16): {"reference_in_float": True}, + ("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001}, + ("outer", f16): {"reference_in_float": True}, + ("round.decimals_3", f16): {"reference_in_float": True}, + ("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02}, + ("nn.functional.triplet_margin_with_distance_loss", f16): { + "atol": 1e-4, + "rtol": 0.02, + }, + ("sinc", f16): {"atol": 0.008, "rtol": 0.002}, + ("torch.ops.aten._safe_softmax.default", f16): {"atol": 5e-4, "rtol": 0.02}, + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002}, + ("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5}, + ("std_mean.unbiased", f16): {"reference_in_float": True}, + "uniform": {"reference_in_float": True}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-4, "rtol": 0.01}, # High atol due to precision loss - ("nn.functional.interpolate.bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0}, - ("nn.functional.upsample_bilinear", "cuda", f64): {"atol": 5e-4, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cpu", f32): {"atol": 5e-3, "rtol": 0}, - ("nn.functional.interpolate.bicubic", "cuda", f64): {"atol": 1e-3, "rtol": 0}, + ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0}, # Unreasonably high atol requirement: - ("index_reduce.mean", "cuda", f16): {"check_gradient": False}, - ("index_reduce.mean", "cuda", f32): {"check_gradient": False}, - ("index_reduce.mean", "cuda", f64): {"check_gradient": False}, + ("index_reduce.mean", f16): {"check_gradient": False}, + ("index_reduce.mean", f32): {"check_gradient": False}, + ("index_reduce.mean", f64): {"check_gradient": False}, # Gradient contains non-finite entries: - ("index_reduce.amin", "cuda", f64): {"check_gradient": False}, - ("index_reduce.amin", "cuda", f32): {"check_gradient": False}, - ("index_reduce.amin", "cuda", f16): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f64): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f32): {"check_gradient": False}, - ("index_reduce.amax", "cuda", f16): {"check_gradient": False}, - ("tanh", "cuda", f16): {"atol": 1e-4, "rtol": 1e-2}, + ("index_reduce.amin", f64): {"check_gradient": False}, + ("index_reduce.amin", f32): {"check_gradient": False}, + ("index_reduce.amin", f16): {"check_gradient": False}, + ("index_reduce.amax", f64): {"check_gradient": False}, + ("index_reduce.amax", f32): {"check_gradient": False}, + ("index_reduce.amax", f16): {"check_gradient": False}, + ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, } +inductor_override_kwargs["xpu"] = { + # the return value of empty is undefined + "empty": {"assert_equal": False}, + "empty_permuted": {"assert_equal": False}, + "empty_like": {"assert_equal": False}, + "new_empty": {"assert_equal": False}, + "empty_strided": {"assert_equal": False}, + "new_empty_strided": {"assert_equal": False}, + "randn": {"assert_equal": False}, + # XPU + ("cross", f16): {"reference_in_float": True}, + ("addr", f16): {"reference_in_float": True}, + ("baddbmm", f16): {"atol": 2e-3, "rtol": 0.002}, # decomp affects accuracy + ("angle", f64): {"reference_in_float": True}, + ("asin", f16): {"reference_in_float": True}, + ("atanh", f16): {"reference_in_float": True}, + "cauchy": {"reference_in_float": True}, + ("cummax", f16): {"atol": 5e-4, "rtol": 0.002}, + ("cumsum", f16): {"reference_in_float": True}, + "cumprod": {"reference_in_float": True, "atol": 7e-5, "rtol": 0.002}, + ("dot", f16): {"atol": 1e-5, "rtol": 0.002}, + "logcumsumexp": {"grad_atol": 8e-4, "grad_rtol": 0.001}, + "exponential": {"reference_in_float": True}, + "geometric": {"reference_in_float": True}, + ("kron", f16): {"reference_in_float": True}, + ("linalg.cross", f16): {"reference_in_float": True}, + ("linalg.vecdot", f16): {"atol": 1e-5, "rtol": 2e-2}, + "log_normal": {"reference_in_float": True}, + ("logsumexp", f16): {"atol": 1e-5, "rtol": 1e-2}, + ("masked.cumprod", f16): {"atol": 1e-5, "rtol": 5e-2}, + ("masked.cumsum", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("masked.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("masked.softmax", f16): {"atol": 2e-4, "rtol": 0.01}, + ("masked.var", f16): {"atol": 2e-5, "rtol": 5e-3}, + ("native_batch_norm", f64): {"atol": 1e-7, "rtol": 1e-5}, + ("_native_batch_norm_legit", f64): {"atol": 1e-7, "rtol": 5e-6}, + ("_batch_norm_with_update", f64): {"atol": 1e-7, "rtol": 1e-6}, + ("native_layer_norm", f16): {"atol": 5e-3, "rtol": 5e-3}, + ("native_layer_norm", f32): {"atol": 5e-3, "rtol": 5e-3}, + ("nn.functional.batch_norm", f16): {"reference_in_float": True}, + ("nn.functional.batch_norm", f64): {"atol": 1e-6, "rtol": 1e-6}, + ("nn.functional.batch_norm.without_cudnn", f16): {"reference_in_float": True}, + ("nn.functional.conv1d", f16): {"atol": 1e-5, "rtol": 6e-3}, + ("nn.functional.conv3d", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.conv_transpose2d", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.conv_transpose3d", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("nn.functional.cosine_embedding_loss", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("nn.functional.cosine_similarity", f16): { + "reference_in_float": True, + "atol": 1e-5, + "rtol": 5e-3, + }, + ("nn.functional.instance_norm", f16): {"reference_in_float": True}, + ("nn.functional.instance_norm", f64): {"atol": 1e-6, "rtol": 1e-6}, + ("nn.functional.layer_norm", f16): {"atol": 5e-3, "rtol": 2e-3}, + ("nn.functional.layer_norm", f32): {"atol": 5e-5, "rtol": 2e-3}, + ("nn.functional.local_response_norm", f16): {"reference_in_float": True}, + ("nn.functional.multilabel_soft_margin_loss", f16): { + "atol": 3e-4, + "rtol": 2e-3, + }, + ("nn.functional.normalize", f16): {"atol": 1e-3, "rtol": 0.05}, + ("nn.functional.rms_norm", f16): {"reference_in_float": True}, + ("nn.functional.soft_margin_loss", f16): {"reference_in_float": True}, + ("nn.functional.softmin", f16): {"atol": 1e-4, "rtol": 0.01}, + ("nn.functional.softsign", f16): { + "reference_in_float": True, + "atol": 1e-5, + "rtol": 0.005, + }, + ("nn.functional.tanhshrink", f16): {"atol": 3e-4, "rtol": 0.001}, + ("outer", f16): {"reference_in_float": True}, + ("round.decimals_3", f16): {"reference_in_float": True}, + ("nn.functional.triplet_margin_loss", f16): {"atol": 1e-4, "rtol": 0.02}, + ("nn.functional.triplet_margin_with_distance_loss", f16): { + "atol": 1e-4, + "rtol": 0.02, + }, + ("remainder", f16): {"atol": 1e-4, "rtol": 0.005}, + ("nn.functional.upsample_bilinear", f16): {"atol": 1e-5, "rtol": 0.002}, + ("sinc", f16): {"atol": 0.008, "rtol": 0.002}, + ("softmax", f16): {"atol": 1e-4, "rtol": 0.02}, + ("_softmax_backward_data", f16): {"atol": 0.008, "rtol": 0.002}, + ("special.log_ndtr", f64): {"atol": 1e-6, "rtol": 1e-5}, + ("std_mean.unbiased", f16): { + "reference_in_float": True, + "atol": 5e-5, + "rtol": 5e-3, + }, + ("trapezoid", f16): {"atol": 1e-5, "rtol": 5e-3}, + ("trapz", f16): {"atol": 1e-5, "rtol": 5e-3}, + "uniform": {"reference_in_float": True}, + ("var_mean", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("var_mean.unbiased", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("vdot", f16): {"atol": 1e-5, "rtol": 2e-3}, + # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors + # High atol due to precision loss + ("nn.functional.interpolate.bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.upsample_bilinear", f64): {"atol": 5e-4, "rtol": 0}, + ("nn.functional.interpolate.bicubic", f64): {"atol": 1e-3, "rtol": 0}, + # Unreasonably high atol requirement: + ("index_reduce.mean", f16): {"check_gradient": False}, + ("index_reduce.mean", f32): {"check_gradient": False}, + ("index_reduce.mean", f64): {"check_gradient": False}, + # Gradient contains non-finite entries: + ("index_reduce.amin", f64): {"check_gradient": False}, + ("index_reduce.amin", f32): {"check_gradient": False}, + ("index_reduce.amin", f16): {"check_gradient": False}, + ("index_reduce.amax", f64): {"check_gradient": False}, + ("index_reduce.amax", f32): {"check_gradient": False}, + ("index_reduce.amax", f16): {"check_gradient": False}, + ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2}, + ("nn.functional.embedding_bag", f16): {"check_gradient": False}, + ("nn.functional.embedding_bag", f32): {"check_gradient": False}, + ("nn.functional.embedding_bag", f64): {"check_gradient": False}, + ("_unsafe_masked_index", f16): {"atol": 1e-5, "rtol": 2e-3}, + ("_unsafe_masked_index_put_accumulate", f16): {"atol": 1e-5, "rtol": 5e-3}, +} # Test with one sample only for following ops -inductor_one_sample = { +inductor_one_sample = defaultdict(dict) + +inductor_one_sample["cpu"] = { "_segment_reduce.lengths": {f16}, "_segment_reduce.offsets": {f16}, "addmv": {f16}, @@ -453,89 +696,244 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "normal": {f16, f32, f64}, "put": {f16, f32, f64}, "take": {b8, f16, f32, f64, i32, i64}, - ("__rdiv__", "cuda"): {f16}, - ("__rmod__", "cuda"): {f16, i64}, - ("__rmul__", "cuda"): {f16}, - ("__rpow__", "cuda"): {f16}, - ("_unsafe_masked_index", "cuda"): {f16}, - ("_unsafe_masked_index_put_accumulate", "cuda"): {f16}, - ("addcdiv", "cuda"): {f16}, - ("addcmul", "cuda"): {f16}, - ("atan2", "cuda"): {f16}, - ("cumsum", "cuda"): {f16}, - ("cumulative_trapezoid", "cuda"): {f16}, - ("dist", "cuda"): {f16}, - ("div.no_rounding_mode", "cuda"): {f16}, - ("fmod", "cuda"): {f16}, - ("grid_sampler_2d", "cuda"): {f16}, - ("index_fill", "cuda"): {f16, f32, f64}, - ("ldexp", "cuda"): {f16}, - ("lerp", "cuda"): {f16}, - ("linalg.householder_product", "cuda"): {f32}, - ("linalg.matrix_norm", "cuda"): {f16}, - ("linalg.vector_norm", "cuda"): {f16}, - ("logspace", "cuda"): {i32, i64}, - ("masked.cumsum", "cuda"): {f16}, - ("masked.logsumexp", "cuda"): {f16}, - ("masked.mean", "cuda"): {b8}, - ("masked.normalize", "cuda"): {f16}, - ("masked.prod", "cuda"): {f16}, - ("masked.std", "cuda"): {f16}, - ("masked.var", "cuda"): {f16}, - ("mul", "cuda"): {f16}, - ("nn.functional.alpha_dropout", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool1d", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool2d", "cuda"): {f16, f32, f64}, - ("nn.functional.avg_pool3d", "cuda"): {f16, f32, f64}, - ("nn.functional.binary_cross_entropy", "cuda"): {f16}, - ("nn.functional.binary_cross_entropy_with_logits", "cuda"): {f16}, - ("nn.functional.conv2d", "cuda"): {f16}, - ("nn.functional.cosine_embedding_loss", "cuda"): {f16}, - ("nn.functional.dropout2d", "cuda"): {f16, f32, f64}, - ("nn.functional.dropout3d", "cuda"): {f16, f32, f64}, - ("nn.functional.dropout", "cuda"): {f16, f32, f64}, - ("nn.functional.feature_alpha_dropout.with_train", "cuda"): {f16, f32, f64}, - ("nn.functional.fractional_max_pool2d", "cuda"): {f16, f32, f64}, - ("nn.functional.fractional_max_pool3d", "cuda"): {f16, f32, f64}, - ("nn.functional.grid_sample", "cuda"): {f16}, - ("nn.functional.group_norm", "cuda"): {f16}, - ("nn.functional.hinge_embedding_loss", "cuda"): {f16}, +} + +inductor_one_sample["cuda"] = { + "_segment_reduce.lengths": {f16}, + "_segment_reduce.offsets": {f16}, + "addmv": {f16}, + "as_strided.partial_views": {f16}, + "corrcoef": {f16}, + "diff": {f16}, + "einsum": {f16, i32}, + "gradient": {f16}, + "histogram": {f32, f64}, + "histogramdd": {f32, f64}, + "index_put": {f16, f32, f64}, + "linalg.eig": {f32, f64}, + "linspace": {f16, i32, i64}, + "linspace.tensor_overload": {f16, f32, f64, i32, i64}, + "logspace": {f16, i32, i64}, + "logspace.tensor_overload": {f16, f32, f64, i32, i64}, + "masked_logsumexp": {i64}, + "max_pool2d_with_indices_backward": {f16, f32, f64}, + "new_empty_strided": {f16}, + "nn.functional.adaptive_avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool1d": {f16, f32}, + "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.bilinear": {f16}, + "nn.functional.conv_transpose1d": {f16}, + "nn.functional.conv_transpose2d": {f16}, + "nn.functional.conv_transpose3d": {f16}, + "nn.functional.cosine_similarity": {f16}, + "nn.functional.cross_entropy": {f16, f32, f64}, + "nn.functional.gaussian_nll_loss": {f16}, + "nn.functional.grid_sample": {f16, f32, f64}, + "nn.functional.interpolate.area": {f16}, + "nn.functional.nll_loss": {f16, f32, f64}, + "normal": {f16, f32, f64}, + "put": {f16, f32, f64}, + "take": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {f16}, + "__rmod__": {f16, i64}, + "__rmul__": {f16}, + "__rpow__": {f16}, + "_unsafe_masked_index": {f16}, + "_unsafe_masked_index_put_accumulate": {f16}, + "addcdiv": {f16}, + "addcmul": {f16}, + "atan2": {f16}, + "cumsum": {f16}, + "cumulative_trapezoid": {f16}, + "dist": {f16}, + "div.no_rounding_mode": {f16}, + "fmod": {f16}, + "grid_sampler_2d": {f16}, + "index_fill": {f16, f32, f64}, + "ldexp": {f16}, + "lerp": {f16}, + "linalg.householder_product": {f32}, + "linalg.matrix_norm": {f16}, + "linalg.vector_norm": {f16}, + "masked.cumsum": {f16}, + "masked.logsumexp": {f16}, + "masked.mean": {b8}, + "masked.normalize": {f16}, + "masked.prod": {f16}, + "masked.std": {f16}, + "masked.var": {f16}, + "mul": {f16}, + "nn.functional.alpha_dropout": {f16, f32, f64}, + "nn.functional.avg_pool1d": {f16, f32, f64}, + "nn.functional.avg_pool2d": {f16, f32, f64}, + "nn.functional.avg_pool3d": {f16, f32, f64}, + "nn.functional.binary_cross_entropy": {f16}, + "nn.functional.binary_cross_entropy_with_logits": {f16}, + "nn.functional.conv2d": {f16}, + "nn.functional.cosine_embedding_loss": {f16}, + "nn.functional.dropout2d": {f16, f32, f64}, + "nn.functional.dropout3d": {f16, f32, f64}, + "nn.functional.dropout": {f16, f32, f64}, + "nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64}, + "nn.functional.fractional_max_pool2d": {f16, f32, f64}, + "nn.functional.fractional_max_pool3d": {f16, f32, f64}, + "nn.functional.group_norm": {f16}, + "nn.functional.hinge_embedding_loss": {f16}, + # Enabling all tests for this test fails randomly + # See https://github.com/pytorch/pytorch/issues/129238 + "nn.functional.huber_loss": {f16}, + "nn.functional.interpolate.bicubic": {f16}, + "nn.functional.interpolate.bilinear": {f16}, + "nn.functional.interpolate.trilinear": {f16}, + "nn.functional.kl_div": {f16}, + "nn.functional.margin_ranking_loss": {f16}, + "nn.functional.max_pool1d": {f16, f32, f64}, + "nn.functional.max_pool3d": {f16}, + "nn.functional.mse_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multilabel_soft_margin_loss": {f16}, + "nn.functional.normalize": {f16}, + "nn.functional.pad.replicate": {f16, f32, f64}, + "nn.functional.pad.reflect": {f16}, + "nn.functional.pairwise_distance": {f16}, + "nn.functional.poisson_nll_loss": {f16}, + "nn.functional.rms_norm": {f16}, + "norm": {f16}, + "pow": {f16}, + "prod": {f16}, + "scatter_reduce.amax": {f16, f32, f64}, + "scatter_reduce.amin": {f16, f32, f64}, + "scatter_reduce.mean": {f16, f32, f64}, + "special.xlog1py": {f16}, + "std": {f16}, + "std_mean": {f16}, + "svd_lowrank": {f32, f64}, + "trapezoid": {f16}, + "trapz": {f16}, + "true_divide": {f16}, + "var": {f16}, + "var_mean": {f16}, + "xlogy": {f16}, +} + +inductor_one_sample["xpu"] = { + "_segment_reduce.lengths": {f16}, + "_segment_reduce.offsets": {f16}, + "addmv": {f16}, + "as_strided.partial_views": {f16}, + "corrcoef": {f16}, + "diff": {f16}, + "einsum": {f16, i32}, + "gradient": {f16}, + "histogram": {f32, f64}, + "histogramdd": {f32, f64}, + "index_put": {f16, f32, f64}, + "linalg.eig": {f32, f64}, + "linspace": {f16, i32, i64}, + "linspace.tensor_overload": {f16, f32, f64, i32, i64}, + "logspace": {f16, i32, i64}, + "logspace.tensor_overload": {f16, f32, f64, i32, i64}, + "masked_logsumexp": {i64}, + "max_pool2d_with_indices_backward": {f16, f32, f64}, + "new_empty_strided": {f16}, + "nn.functional.adaptive_avg_pool3d": {f16}, + "nn.functional.adaptive_max_pool1d": {f16, f32}, + "nn.functional.adaptive_max_pool2d": {f16, f32}, + "nn.functional.bilinear": {f16}, + "nn.functional.conv_transpose1d": {f16}, + "nn.functional.conv_transpose2d": {f16}, + "nn.functional.conv_transpose3d": {f16}, + "nn.functional.cosine_similarity": {f16}, + "nn.functional.cross_entropy": {f16, f32, f64}, + "nn.functional.gaussian_nll_loss": {f16}, + "nn.functional.grid_sample": {f16, f32, f64}, + "nn.functional.interpolate.area": {f16}, + "nn.functional.nll_loss": {f16, f32, f64}, + "normal": {f16, f32, f64}, + "put": {f16, f32, f64}, + "take": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {f16}, + "__rmod__": {f16, i64}, + "__rmul__": {f16}, + "__rpow__": {f16}, + "_unsafe_masked_index": {f16}, + "_unsafe_masked_index_put_accumulate": {f16}, + "addcdiv": {f16}, + "addcmul": {f16}, + "atan2": {f16}, + "cumsum": {f16}, + "cumulative_trapezoid": {f16}, + "dist": {f16}, + "div.no_rounding_mode": {f16}, + "fmod": {f16}, + "grid_sampler_2d": {f16}, + "index_fill": {f16, f32, f64}, + "ldexp": {f16}, + "lerp": {f16}, + "linalg.householder_product": {f32}, + "linalg.matrix_norm": {f16}, + "linalg.vector_norm": {f16}, + "masked.cumsum": {f16}, + "masked.logsumexp": {f16}, + "masked.mean": {b8}, + "masked.normalize": {f16}, + "masked.prod": {f16}, + "masked.std": {f16}, + "masked.var": {f16}, + "mul": {f16}, + "nn.functional.alpha_dropout": {f16, f32, f64}, + "nn.functional.avg_pool1d": {f16, f32, f64}, + "nn.functional.avg_pool2d": {f16, f32, f64}, + "nn.functional.avg_pool3d": {f16, f32, f64}, + "nn.functional.binary_cross_entropy": {f16}, + "nn.functional.binary_cross_entropy_with_logits": {f16}, + "nn.functional.conv2d": {f16}, + "nn.functional.cosine_embedding_loss": {f16}, + "nn.functional.dropout2d": {f16, f32, f64}, + "nn.functional.dropout3d": {f16, f32, f64}, + "nn.functional.dropout": {f16, f32, f64}, + "nn.functional.feature_alpha_dropout.with_train": {f16, f32, f64}, + "nn.functional.fractional_max_pool2d": {f16, f32, f64}, + "nn.functional.fractional_max_pool3d": {f16, f32, f64}, + "nn.functional.group_norm": {f16}, + "nn.functional.hinge_embedding_loss": {f16}, # Enabling all tests for this test fails randomly # See https://github.com/pytorch/pytorch/issues/129238 - ("nn.functional.huber_loss", "cuda"): {f16}, - ("nn.functional.interpolate.bicubic", "cuda"): {f16}, - ("nn.functional.interpolate.bilinear", "cuda"): {f16}, - ("nn.functional.interpolate.trilinear", "cuda"): {f16}, - ("nn.functional.kl_div", "cuda"): {f16}, - ("nn.functional.margin_ranking_loss", "cuda"): {f16}, - ("nn.functional.max_pool1d", "cuda"): {f16, f32, f64}, - ("nn.functional.max_pool3d", "cuda"): {f16}, - ("nn.functional.mse_loss", "cuda"): {f16}, - ("nn.functional.multi_margin_loss", "cuda"): {f16}, - ("nn.functional.multilabel_margin_loss", "cuda"): {f16}, - ("nn.functional.multilabel_soft_margin_loss", "cuda"): {f16}, - ("nn.functional.normalize", "cuda"): {f16}, - ("nn.functional.pad.replicate", "cuda"): {f16, f32, f64}, - ("nn.functional.pad.reflect", "cuda"): {f16}, - ("nn.functional.pairwise_distance", "cuda"): {f16}, - ("nn.functional.poisson_nll_loss", "cuda"): {f16}, - ("nn.functional.rms_norm", "cuda"): {f16}, - ("norm", "cuda"): {f16}, - ("pow", "cuda"): {f16}, - ("prod", "cuda"): {f16}, - ("scatter_reduce.amax", "cuda"): {f16, f32, f64}, - ("scatter_reduce.amin", "cuda"): {f16, f32, f64}, - ("scatter_reduce.mean", "cuda"): {f16, f32, f64}, - ("special.xlog1py", "cuda"): {f16}, - ("std", "cuda"): {f16}, - ("std_mean", "cuda"): {f16}, - ("svd_lowrank", "cuda"): {f32, f64}, - ("trapezoid", "cuda"): {f16}, - ("trapz", "cuda"): {f16}, - ("true_divide", "cuda"): {f16}, - ("var", "cuda"): {f16}, - ("var_mean", "cuda"): {f16}, - ("xlogy", "cuda"): {f16}, + "nn.functional.huber_loss": {f16}, + "nn.functional.interpolate.bicubic": {f16}, + "nn.functional.interpolate.bilinear": {f16}, + "nn.functional.interpolate.trilinear": {f16}, + "nn.functional.kl_div": {f16}, + "nn.functional.margin_ranking_loss": {f16}, + "nn.functional.max_pool1d": {f16, f32, f64}, + "nn.functional.max_pool3d": {f16}, + "nn.functional.mse_loss": {f16}, + "nn.functional.multi_margin_loss": {f16}, + "nn.functional.multilabel_margin_loss": {f16}, + "nn.functional.multilabel_soft_margin_loss": {f16}, + "nn.functional.normalize": {f16}, + "nn.functional.pad.replicate": {f16, f32, f64}, + "nn.functional.pad.reflect": {f16}, + "nn.functional.pairwise_distance": {f16}, + "nn.functional.poisson_nll_loss": {f16}, + "nn.functional.rms_norm": {f16}, + "norm": {f16}, + "pow": {f16}, + "prod": {f16}, + "scatter_reduce.amax": {f16, f32, f64}, + "scatter_reduce.amin": {f16, f32, f64}, + "scatter_reduce.mean": {f16, f32, f64}, + "special.xlog1py": {f16}, + "std": {f16}, + "std_mean": {f16}, + "svd_lowrank": {f32, f64}, + "trapezoid": {f16}, + "trapz": {f16}, + "true_divide": {f16}, + "var": {f16}, + "var_mean": {f16}, + "xlogy": {f16}, } @@ -569,6 +967,7 @@ def tearDown(self): True ) # inductor kernels failing this test intermittently @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipXPUIf(not HAS_XPU, "Skipped! Supported XPU compiler not found") @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @skipIfTorchDynamo("Test uses dynamo already") @@ -590,6 +989,8 @@ def test_comprehensive(self, device, dtype, op): # TODO: should we move empty_cache to the common device interface if device_type == "cuda": torch.cuda.empty_cache() + elif device == "xpu": + torch.xpu.empty_cache() op_name = op.name if op.variant_test_name: op_name += f".{op.variant_test_name}" @@ -627,12 +1028,12 @@ def test_comprehensive(self, device, dtype, op): test_expect = ExpectedTestResult.SUCCESS overridden_kwargs = {} - if op_name in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[op_name] - elif (op_name, device_type) in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[(op_name, device_type)] - elif (op_name, device_type, dtype) in inductor_override_kwargs: - overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)] + overridden_kwargs.update( + inductor_override_kwargs.get(device_type, {}).get(op_name, {}) + ) + overridden_kwargs.update( + inductor_override_kwargs.get(device_type, {}).get((op_name, dtype), {}) + ) func = op.get_op() def fn(*args, **kwargs): @@ -650,8 +1051,7 @@ def fn(*args, **kwargs): samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) if ( - dtype in inductor_one_sample.get(op_name, {}) - or dtype in inductor_one_sample.get((op_name, device_type), {}) + dtype in inductor_one_sample.get(device_type, {}).get(op_name, {}) ) and not ALL_SAMPLES: if isinstance(samples, (list, tuple)): samples = [samples[0]] @@ -793,7 +1193,7 @@ def _get_tolerances(dtype): # print(f"SUCCEEDED OP {op_name} on {device_type} with {dtype}", flush=True, file=f) -instantiate_device_type_tests(TestInductorOpInfo, globals()) +instantiate_device_type_tests(TestInductorOpInfo, globals(), allow_xpu=True) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 33782d76b7f934..16424fe1b4f008 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -241,7 +241,7 @@ def get_input(view_size: Tuple[int]) -> torch.Tensor: ((3 * max_block, 2), 3, 2), # Multiple of max block. Uses loops. ( (2, 3 * max_block), - 3, + 2, 2, ), # Multiple of max block. Uses loops. ((128, 128), 3, 2), # Test a large size, with loops. diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index b6e04bf99220bf..3d3fc29f3b398d 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -10,8 +10,10 @@ try: - from extension_backends.triton.device_interface import DeviceInterface - from extension_backends.triton.extension_codegen_backend import ( + from extension_backends.triton.device_interface import ( + DeviceInterface, # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend + ) + from extension_backends.triton.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 CPUDeviceOpOverrides, ExtensionScheduling, ExtensionWrapperCodegen, @@ -41,7 +43,7 @@ try: from . import test_torchinductor except ImportError: - import test_torchinductor + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library except unittest.SkipTest: if __name__ == "__main__": sys.exit(0) diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index bcb0584b18b0fe..24f322dfebb845 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -9,15 +9,21 @@ try: - import triton # noqa: F401 + import triton # noqa: F401 # @manual + import triton.language as tl # @manual except ImportError: if __name__ == "__main__": sys.exit(0) raise unittest.SkipTest("requires triton") # noqa: B904 from torch._inductor import config -from torch._inductor.runtime.hints import TRITON_MAX_BLOCK -from torch._inductor.runtime.triton_heuristics import triton_config +from torch._inductor.runtime.hints import ( + DeviceProperties, + HeuristicType, + TRITON_MAX_BLOCK, +) +from torch._inductor.runtime.triton_helpers import math as tl_math +from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config from torch._inductor.test_case import run_tests, TestCase @@ -81,6 +87,59 @@ def test_artificial_zgrid(self): def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() + def _get_cos_kernel_caching_autotuner_args(self): + from triton.compiler.compiler import AttrsDescriptor # @manual + + @triton.jit + def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): + xnumel = 16 + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + tmp1 = tl_math.cos(tmp0) + tl.store(out_ptr0 + (x0), tmp1, xmask) + + triton_meta = { + "signature": {0: "*fp32", 1: "*fp32", 2: "i32"}, + "device": DeviceProperties.create(torch.device("cuda")), + "constants": {}, + "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], + } + + configs = [ + triton_config([16], 64), + triton_config([256], 64), + ] + + inductor_meta = {} + + return { + "fn": triton_, + "triton_meta": triton_meta, + "configs": configs, + "save_cache_hook": False, + "mutated_arg_names": [], + "heuristic_type": HeuristicType.POINTWISE, + "inductor_meta": inductor_meta, + } + + @skipIfXpu + def test_pre_hook_assert(self): + # assert if any of the configs passed to the CachingAutotuner have pre-hooks + args = self._get_cos_kernel_caching_autotuner_args() + + def pre_hook(kwargs): + if "in_ptr0" in kwargs: + kwargs["in_ptr0"].zero_() + + for cfg in args["configs"]: + cfg.pre_hook = pre_hook + + with self.assertRaisesRegex(AssertionError, "pre_hook"): + autotuner = CachingAutotuner(**args) + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index e2a659c0cc190a..c11bfbcb790c70 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -16,8 +16,14 @@ from torch._inductor.utils import run_and_get_code from torch._library import capture_triton from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + parametrize, + skipIfRocm, + skipIfXpu, + TEST_WITH_ROCM, +) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU +from torch.testing._internal.logging_utils import logs_to_string # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 @@ -30,12 +36,12 @@ if not TEST_WITH_ROCM: if HAS_CUDA: - from triton.language.extra.cuda.libdevice import ( + from triton.language.extra.cuda.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) elif HAS_XPU: - from triton.language.extra.intel.libdevice import ( + from triton.language.extra.intel.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) @@ -266,6 +272,57 @@ def call_triton_return_view(x: torch.Tensor): self.assertEqual(2 * t_view, compiled_func(t).view(16)) self.assertEqual(2 * t, compiled_func(t)) + @requires_gpu + def test_no_nan_kernels(self): + @triton.jit + def add_one_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x + 1 + tl.store(out_ptr + offsets, output, mask=mask) + + def add_one(x, out): + n_elements = x.numel() + add_one_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4) + + class AddOne(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + out = torch.empty_like(x) + add_one(x, out) + ctx.save_for_backward(out) + return out + + @staticmethod + def backward(ctx, grad): + (saved,) = ctx.saved_tensors + out = torch.empty_like(grad) + add_one(saved, out) + return out + + @torch.compile + def f(x): + return AddOne.apply(x) + + log_stream, ctx = logs_to_string("torch._inductor.codecache", "output_code") + + x = torch.randn(3, requires_grad=True, device=GPU_TYPE) + with ctx(): + y = f(x) + + output_code = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip() + self.assertTrue(len(output_code) > 0, msg="output code is not empty") + self.assertEqual(output_code.count('float("nan")'), 0) + self.assertEqual(output_code.count("float('nan')"), 0) + @requires_gpu @common_utils.parametrize("grad_fn", [torch.no_grad, torch.enable_grad]) @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) @@ -958,6 +1015,66 @@ def f(inp): compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) + @torch._inductor.config.patch( + triton_kernel_default_layout_constraint="needs_fixed_stride_order" + ) + @requires_gpu + def test_layout_constraint_needs_fixed_stride_order(self): + # Construct a custom op whose output strides are (1, 2) + @torch.library.custom_op("mylib::weird_op_with_lowering", mutates_args={}) + def weird_op_with_lowering(x: torch.Tensor) -> torch.Tensor: + return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device) + + @weird_op_with_lowering.register_fake + def _(x): + return torch.empty_strided((2, 2), (1, 2), dtype=x.dtype, device=x.device) + + # The lowering for the custom op produces output strides (2, 1). + from torch._inductor.lowering import empty_strided, register_lowering + + @register_lowering(torch.ops.mylib.weird_op_with_lowering) + def _(x): + return empty_strided( + x.shape, (2, 1), dtype=x.dtype, device=torch.device(GPU_TYPE, 0) + ) + + # Triton kernel that has different behavior depending on the input strides. + @triton.jit + def kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + output = offsets + tl.store(out_ptr + offsets, output, mask=mask) + + def arange_out(x, out): + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + kernel[grid](x, out, n_elements, BLOCK_SIZE=4) + + def f(x): + y = weird_op_with_lowering(x) + # Inductor lowering will decide that y is better having strides (2, 1). + # This is different from the strides at tracing time (1, 2). + # Under the "needs_fixed_stride_order" config, inductor will coerce + # y to have strides (1, 2) before passing it to arange_out. + # If it doesn't, then the result will be different from eager mode. + arange_out(x, y) + return x + y + + x = torch.randn(2, 2, device=GPU_TYPE) + eager_out = f(x) + + compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True) + compiled_inductor_out = compiled_inductor_f(x) + self.assertEqual(compiled_inductor_out, eager_out) + @requires_gpu def test_triton_kernel_strided_input_nonzero_offset(self): def f(inp): @@ -1493,6 +1610,18 @@ def f(x, y): x = torch.randn(4, device=GPU_TYPE) f(x, x) + @requires_gpu + @parametrize("dtype", (torch.float16, torch.float32, torch.float64)) + def test_triton_kernel_float64_constant(self, dtype): + def f(x): + return x * (0.12 * x.shape[0]) + + x = torch.ones(200, device=GPU_TYPE, dtype=dtype) + + eager_out = f(x) + compiled_out = torch.compile(f, dynamic=True)(x) + self.assertEqual(compiled_out, eager_out) + def make_mutation_test(fn): @requires_gpu @@ -2414,8 +2543,8 @@ def f(x, y): @requires_gpu def test_capture_triton_disabled_in_triton_op(self): - import triton - import triton.language as tl + import triton # @manual + import triton.language as tl # @manual @triton.jit def add_kernel( diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 14668d51b1137e..0ef2e6131166c8 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -9,9 +9,12 @@ from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_big_gpu from torch.testing import make_tensor -from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + skipCUDAIf, +) from torch.testing._internal.common_utils import IS_LINUX, parametrize -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, skipCUDAIf +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA class TestUnbackedSymints(InductorTestCase): @@ -194,6 +197,9 @@ def test_vertical_pointwise_reduction_fusion(self, device): # reset in case we run both cpu and cuda tests torch._inductor.metrics.reset() + if device == "cpu": + raise unittest.SkipTest("This test requires cuda") + # Tests fusing a pointwise & reduction op with unbacked numel/rnumel. def fn(x, y, repeats): u0 = repeats.item() diff --git a/test/inductor/test_xpu_basic.py b/test/inductor/test_xpu_basic.py index acc197c35f2104..f4bf30e4f2d7d8 100644 --- a/test/inductor/test_xpu_basic.py +++ b/test/inductor/test_xpu_basic.py @@ -20,7 +20,10 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) -from inductor.test_torchinductor import check_model_gpu, TestCase +from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + check_model_gpu, + TestCase, +) # TODO: Remove this file. diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index ac3d8e40b05529..9cbbacb087103f 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -46,7 +46,6 @@ def removeExceptions(graph): n.destroy() -@skipIfTorchDynamo("somehow causing hanging during python shutdown") class TestFreezing(JitTestCase): def test_freeze_module(self): class M(nn.Module): diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index 40dca90b164887..c0419664d0098e 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,6 +17,7 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, + skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -745,6 +746,7 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): + @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 46980f2ce64aaf..1d280dfa03450b 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -51,17 +51,19 @@ def tearDown(self): OrtBackend.clear_cached_instances() def test_get_ort_device_type(self): + from onnxruntime.capi import _pybind_state as ORTC + self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cuda(), + ORTC.OrtDevice.cuda(), ) self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.cpu(), + ORTC.OrtDevice.cpu(), ) self.assertEqual( torch.onnx._internal.onnxruntime._get_ort_device_type("maia"), - torch.onnx._internal.onnxruntime.ORTC.OrtDevice.npu(), + ORTC.OrtDevice.npu(), ) def test_torch_compile_backend_registration(self): diff --git a/test/onnx/dynamo/test_exporter_api.py b/test/onnx/dynamo/test_exporter_api.py index c6e8cd47cfe24d..e2c8cf3df0b6c3 100644 --- a/test/onnx/dynamo/test_exporter_api.py +++ b/test/onnx/dynamo/test_exporter_api.py @@ -1,16 +1,11 @@ # Owner(s): ["module: onnx"] import io -import os import onnx import torch from torch.onnx import dynamo_export, ExportOptions, ONNXProgram -from torch.onnx._internal import _exporter_legacy -from torch.onnx._internal._exporter_legacy import ( - ONNXProgramSerializer, - ResolvedExportOptions, -) +from torch.onnx._internal._exporter_legacy import ResolvedExportOptions from torch.testing._internal import common_utils @@ -75,79 +70,6 @@ def test_save_to_existing_buffer_default_serializer(self): dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer) onnx.load(buffer) - def test_save_to_file_using_specified_serializer(self): - expected_buffer = "I am not actually ONNX" - - class CustomSerializer(ONNXProgramSerializer): - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - destination.write(expected_buffer.encode()) - - with common_utils.TemporaryFileName() as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( - path, serializer=CustomSerializer() - ) - with open(path) as fp: - self.assertEqual(fp.read(), expected_buffer) - - def test_save_to_file_using_specified_serializer_without_inheritance(self): - expected_buffer = "I am not actually ONNX" - - # NOTE: Inheritance from `ONNXProgramSerializer` is not required. - # Because `ONNXProgramSerializer` is a Protocol class. - class CustomSerializer: - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - destination.write(expected_buffer.encode()) - - with common_utils.TemporaryFileName() as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save( - path, serializer=CustomSerializer() - ) - with open(path) as fp: - self.assertEqual(fp.read(), expected_buffer) - - def test_save_sarif_log_to_file_with_successful_export(self): - with common_utils.TemporaryFileName(suffix=".sarif") as path: - dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path) - self.assertTrue(os.path.exists(path)) - - def test_save_sarif_log_to_file_with_failed_export(self): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(RuntimeError): - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - self.assertTrue( - os.path.exists(_exporter_legacy._DEFAULT_FAILED_EXPORT_SARIF_LOG_PATH) - ) - - def test_onnx_program_accessible_from_exception_when_export_failed(self): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(torch.onnx.OnnxExporterError) as cm: - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - self.assertIsInstance(cm.exception, torch.onnx.OnnxExporterError) - self.assertIsInstance(cm.exception.onnx_program, ONNXProgram) - - def test_access_onnx_program_model_proto_raises_when_onnx_program_is_emitted_from_failed_export( - self, - ): - class ModelWithExportError(torch.nn.Module): - def forward(self, x): - raise RuntimeError("Export error") - - with self.assertRaises(torch.onnx.OnnxExporterError) as cm: - dynamo_export(ModelWithExportError(), torch.randn(1, 1, 2)) - onnx_program = cm.exception.onnx_program - with self.assertRaises(RuntimeError): - onnx_program.model_proto - def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_is_true( self, ): diff --git a/test/onnx/dynamo/test_registry_dispatcher.py b/test/onnx/dynamo/test_registry_dispatcher.py index d782f94dd311d9..8605b098d30055 100644 --- a/test/onnx/dynamo/test_registry_dispatcher.py +++ b/test/onnx/dynamo/test_registry_dispatcher.py @@ -8,18 +8,11 @@ import onnxscript # type: ignore[import] from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import] -from onnxscript.function_libs.torch_lib import ops # type: ignore[import] from onnxscript.onnx_opset import opset15 as op # type: ignore[import] import torch import torch.fx -from torch.onnx._internal.diagnostics import infra -from torch.onnx._internal.fx import ( - analysis, - diagnostics, - onnxfunction_dispatcher, - registration, -) +from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher, registration from torch.testing._internal import common_utils @@ -84,60 +77,6 @@ def test_custom(x, y): [test_original, test_custom], ) - def test_unsupported_nodes_analysis_with_missing_aten_op(self): - # NOTE: simulate unsupported nodes - aten_mul_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="mul", overload="Tensor" - ) - aten_mul_default = registration.OpName.from_name_parts( - namespace="aten", op_name="mul" - ) - aten_add_tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - aten_add_default = registration.OpName.from_name_parts( - namespace="aten", op_name="add" - ) - - self.registry._registry.pop(aten_mul_tensor) - self.registry._registry.pop(aten_mul_default) - self.registry._registry.pop(aten_add_tensor) - self.registry._registry.pop(aten_add_default) - - diagnostic_context = diagnostics.DiagnosticContext( - "torch.onnx.dynamo_export", torch.__version__ - ) - dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( - self.registry, diagnostic_context - ) - - graph: torch.fx.Graph = torch.fx.Graph() - x: torch.fx.Node = graph.create_node("placeholder", "x") - x.meta["val"] = torch.tensor(3.0) - b: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x) - ) - c: torch.fx.Node = graph.create_node( - "call_function", target=torch.ops.aten.add.Tensor, args=(b, b) - ) - output: torch.fx.Node = graph.output(c) - module = torch.fx.GraphModule(torch.nn.Module(), graph) - - with self.assertRaises(infra.RuntimeErrorWithDiagnostic): - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - - try: - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, module, dispatcher - ).analyze(infra.levels.ERROR) - except infra.RuntimeErrorWithDiagnostic as e: - self.assertIn( - "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.", - e.diagnostic.message, - ) - @common_utils.instantiate_parametrized_tests class TestDispatcher(common_utils.TestCase): @@ -488,165 +427,5 @@ def test_first_custom_op( self.assertEqual(symbolic_fn, test_third_custom_op) -@common_utils.instantiate_parametrized_tests -class TestOpSchemaWrapper(common_utils.TestCase): - def setUp(self): - # overload type: optional dtype - self.onnx_function_new_full = ops.core.aten_new_full - self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype - - @common_utils.parametrize( - "inputs, attributes, assertion", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True), - name="perfect_match_with_kwargs", - ), - common_utils.subtest( - (["A", "B"], {}, False), - name="non_perfect_match_due_to_non_tensor_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False), - name="non_perfect_match_due_to_too_many_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False), - name="non_perfect_match_due_to_wrong_kwargs", - ), - ], - ) - def test_perfect_match_inputs(self, inputs, attributes, assertion): - # OnnxFunction with default attributes - dummy_diagnostic = diagnostics.Diagnostic( - rule=diagnostics.rules.find_opschema_matched_symbolic_function, - level=diagnostics.levels.WARNING, - ) - op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker( - ops.core.aten_add - ) - self.assertEqual( - op_schema_wrapper_add.perfect_match_inputs( - dummy_diagnostic, inputs, attributes - ), - assertion, - ) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul, - 0, - ), - name="match_0_inputs", - ), - common_utils.subtest( - ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0), - name="match_0_inputs_bool", - ), - common_utils.subtest( - ( - [ - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(), - ], - {}, - ops.core.aten_mul_bool, - 2, - ), - name="match_2_inputs_bool", - ), - ], - ) - def test_matching_score_system_on_overload_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "inputs, kwargs, op, score", - [ - common_utils.subtest( - ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2), - name="match_2_inputs", - ), - common_utils.subtest( - ( - [torch.randn(3, 4), torch.tensor(3)], - {"dtype": 2}, # at this point, dtype should be converted to int - ops.core.aten_new_full_dtype, - 2, - ), - name="match_2_input_and_match_1_kwargs_optional", - ), - ], - ) - def test_matching_score_system_on_optional_dtypes(self, inputs, kwargs, op, score): - op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(op) - op_schema_wrapper._record_matching_score(inputs, kwargs) - self.assertEqual(op_schema_wrapper.match_score, score) - - @common_utils.parametrize( - "value, expected_onnx_str_dtype", - [ - common_utils.subtest( - (1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}), - name="all_ints", - ), - common_utils.subtest( - (1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}), - name="all_floats", - ), - common_utils.subtest( - (torch.tensor([True]), {"tensor(bool)"}), - name="bool", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}), - name="int64", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}), - name="int32", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}), - name="int16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float), {"tensor(float)"}), - name="float", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}), - name="float16", - ), - common_utils.subtest( - (torch.tensor([1], dtype=torch.double), {"tensor(double)"}), - name="double", - ), - common_utils.subtest((None, set()), name="None"), # None allows no dtype - common_utils.subtest( - ([], set()), name="empaty_list" - ), # Empty list allows no dtype - ], - ) - def test_find_onnx_data_type(self, value, expected_onnx_str_dtype): - self.assertEqual( - onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype - ) - - if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 907db3d96a0e18..4a0ead2ed829b4 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -72,6 +72,28 @@ def test_dynamic_axes_supports_partial_dynamic_shapes(self): }, ) + def test_dynamic_axes_supports_output_names(self): + self.assert_export( + SampleModelForDynamicShapes(), + (torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}), + dynamic_axes={ + "b": [0, 1, 2], + }, + ) + onnx_program = torch.onnx.export( + SampleModelForDynamicShapes(), + ( + torch.randn(2, 2, 3), + torch.randn(2, 2, 3), + ), + input_names=["x", "b"], + output_names=["x_out", "b_out"], + dynamic_axes={"b": [0, 1, 2], "b_out": [0, 1, 2]}, + dynamo=True, + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + def test_saved_f_exists_after_export(self): with common_utils.TemporaryFileName(suffix=".onnx") as path: _ = torch.onnx.export( @@ -126,6 +148,63 @@ def test_partial_dynamic_shapes(self): }, ) + def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self): + os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1" + assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1" + + class Nested(torch.nn.Module): + def forward(self, x): + (a0, a1), (b0, b1), (c0, c1, c2) = x + return a0 + a1 + b0 + b1 + c0 + c1 + c2 + + inputs = ( + (1, 2), + ( + torch.randn(4, 4), + torch.randn(4, 4), + ), + ( + torch.randn(4, 4), + torch.randn(4, 4), + torch.randn(4, 4), + ), + ) + + onnx_program = torch.onnx.dynamo_export( + Nested(), + inputs, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + + def test_refine_dynamic_shapes_with_onnx_export(self): + # NOTE: From test/export/test_export.py + + # refine lower, upper bound + class TestRefineDynamicShapeModel(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] >= 6 and y.shape[0] <= 16: + return x * 2.0, y + 1 + + inps = (torch.randn(16), torch.randn(12)) + dynamic_shapes = { + "x": (torch.export.Dim("dx"),), + "y": (torch.export.Dim("dy"),), + } + self.assert_export( + TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes + ) + + def test_zero_output_aten_node(self): + class Model(torch.nn.Module): + def forward(self, x): + torch.ops.aten._assert_async.msg(torch.tensor(True), "assertion failed") + return x + x + + input = torch.randn(2) + self.assert_export(Model(), (input)) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/exporter/test_core.py b/test/onnx/exporter/test_core.py new file mode 100644 index 00000000000000..fc776ecc673b5b --- /dev/null +++ b/test/onnx/exporter/test_core.py @@ -0,0 +1,75 @@ +# Owner(s): ["module: onnx"] +"""Unit tests for the _core module.""" + +from __future__ import annotations + +import numpy as np + +import torch +from torch.onnx._internal.exporter import _core +from torch.testing._internal import common_utils + + +@common_utils.instantiate_parametrized_tests +class TorchTensorTest(common_utils.TestCase): + @common_utils.parametrize( + "dtype, np_dtype", + [ + (torch.bfloat16, np.uint16), + (torch.bool, np.bool_), + (torch.complex128, np.complex128), + (torch.complex64, np.complex64), + (torch.float16, np.float16), + (torch.float32, np.float32), + (torch.float64, np.float64), + (torch.float8_e4m3fn, np.uint8), + (torch.float8_e4m3fnuz, np.uint8), + (torch.float8_e5m2, np.uint8), + (torch.float8_e5m2fnuz, np.uint8), + (torch.int16, np.int16), + (torch.int32, np.int32), + (torch.int64, np.int64), + (torch.int8, np.int8), + (torch.uint16, np.uint16), + (torch.uint32, np.uint32), + (torch.uint64, np.uint64), + (torch.uint8, np.uint8), + ], + ) + def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype): + tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype)) + self.assertEqual(tensor.numpy().dtype, np_dtype) + self.assertEqual(tensor.__array__().dtype, np_dtype) + self.assertEqual(np.array(tensor).dtype, np_dtype) + + @common_utils.parametrize( + "dtype", + [ + (torch.bfloat16), + (torch.bool), + (torch.complex128), + (torch.complex64), + (torch.float16), + (torch.float32), + (torch.float64), + (torch.float8_e4m3fn), + (torch.float8_e4m3fnuz), + (torch.float8_e5m2), + (torch.float8_e5m2fnuz), + (torch.int16), + (torch.int32), + (torch.int64), + (torch.int8), + (torch.uint16), + (torch.uint32), + (torch.uint64), + (torch.uint8), + ], + ) + def test_tobytes(self, dtype: torch.dtype): + tensor = _core.TorchTensor(torch.tensor([1], dtype=dtype)) + self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/exporter/test_tensors.py b/test/onnx/exporter/test_tensors.py new file mode 100644 index 00000000000000..5f1fb0e88b6782 --- /dev/null +++ b/test/onnx/exporter/test_tensors.py @@ -0,0 +1,22 @@ +# Owner(s): ["module: onnx"] +"""Unit tests for the _tensors module.""" + +from __future__ import annotations + +import onnxscript + +from torch.onnx._internal.exporter import _tensors +from torch.testing._internal import common_utils + + +class SymbolicTensorTest(common_utils.TestCase): + def test_it_is_hashable(self): + tensor = _tensors.SymbolicTensor( + opset=onnxscript.values.Opset(domain="test", version=1) + ) + self.assertEqual(hash(tensor), hash(tensor)) + self.assertIn(tensor, {tensor}) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 79d59a1e816816..69c35e44d57555 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -32,7 +32,6 @@ import torch from torch import export as torch_export from torch.onnx import _constants, verification -from torch.onnx._internal.fx import diagnostics from torch.testing._internal import common_utils from torch.testing._internal.opinfo import core as opinfo_core from torch.types import Number @@ -286,35 +285,19 @@ def run_test_with_fx_to_onnx_exporter_and_onnx_runtime( # Feed args and kwargs into exporter. # Note that exporter should flatten kwargs into positional args the exported model; # since ONNX doesn't represent kwargs. - export_error: Optional[torch.onnx.OnnxExporterError] = None - try: - with _dynamo_config.patch(do_not_emit_runtime_asserts=True): - onnx_program = torch.onnx.dynamo_export( - ref_model, - *ref_input_args, - **ref_input_kwargs, - export_options=torch.onnx.ExportOptions( - dynamic_shapes=self.dynamic_shapes, - diagnostic_options=torch.onnx.DiagnosticOptions( - verbosity_level=logging.DEBUG - ), + with _dynamo_config.patch(do_not_emit_runtime_asserts=True): + onnx_program = torch.onnx.dynamo_export( + ref_model, + *ref_input_args, + **ref_input_kwargs, + export_options=torch.onnx.ExportOptions( + dynamic_shapes=self.dynamic_shapes, + diagnostic_options=torch.onnx.DiagnosticOptions( + verbosity_level=logging.DEBUG ), - ) - except torch.onnx.OnnxExporterError as e: - export_error = e - onnx_program = e.onnx_program - - if diagnostics.is_onnx_diagnostics_log_artifact_enabled(): - onnx_program.save_diagnostics( - f"test_report_{self._testMethodName}" - f"_dynamic_axes_{self.dynamic_shapes}" - f"_model_type_{self.model_type}" - ".sarif" + ), ) - if export_error is not None: - raise export_error - if not skip_dynamic_shapes_check: assert_dynamic_shapes(onnx_program, self.dynamic_shapes) diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py deleted file mode 100644 index 9972ea901722c7..00000000000000 --- a/test/onnx/test_fx_op_consistency.py +++ /dev/null @@ -1,2076 +0,0 @@ -# Owner(s): ["module: onnx"] - -"""Test consistency between the output values of torch.onnx FX exported operators -and torch operators given the same inputs. - -Usage: - - 1. Test all operators: - - pytest test/onnx/test_fx_op_consistency.py - - 2. To run tests on a specific operator (e.g. torch.ceil): - - pytest test/onnx/test_fx_op_consistency.py -k ceil - pytest test/onnx/test_fx_op_consistency.py -k nn_functional_scaled_dot_product_attention - - 3. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of errors. E.g. - - CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/test_fx_op_consistency.py -k div_mode_int - - NOTE: Read more on Running and writing tests: - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - -Note: - - 1. Please make sure pytest-subtests is installed. Otherwise, the sub-tests will be ignored. - - 2. Install pytest-xdist to run tests in parallel if runng all tests is the goal. - - 3. When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES and - TESTED_OPS lists. See "Modify this section" - -""" - -from __future__ import annotations - -import copy -import itertools -import os -from typing import ( - Any, - Callable, - Collection, - List, - Mapping, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) - -import error_reproduction -import onnx_test_common -import parameterized -import pytest -import pytorch_test_common -from onnx_test_common import skip, skip_slow, xfail - -import torch -from torch.onnx._internal.diagnostics import _rules -from torch.testing._internal import ( - common_device_type, - common_methods_invocations, - common_utils, -) - - -if TYPE_CHECKING: - from torch.testing._internal.opinfo import core as opinfo_core - - -# NOTE: For ATen signature modifications that will break ONNX export, -# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip -# to make the signal apparent for maintainers. -def xfail_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], bool]] = None, - enabled_if: bool = True, -): - """Prefer using this (xfail) over skip when possible. - - Only skip when the test is not failing consistently. - """ - return xfail( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -def skip_torchlib_forward_compatibility( - op_name: str, - variant_name: str = "", - *, - reason: str, - github_issue: str, - opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None, - dtypes: Optional[Collection[torch.dtype]] = None, - matcher: Optional[Callable[[Any], Any]] = None, - enabled_if: bool = True, -): - """Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible. - - Only skip when the test is not failing consistently. - """ - return skip( - op_name, - variant_name=variant_name, - reason=f"{reason}. GitHub Issue: {github_issue}", - opsets=opsets, - dtypes=dtypes, - matcher=matcher, - enabled_if=enabled_if, - ) - - -# fmt: off -# Turn off black formatting to keep the list compact - -# Expected failures for onnx export. -# The list should be sorted alphabetically by op name. -# Q: When should I use fixme vs vs skip vs xfail? -# A: Prefer xfail over skip when possible. -# 2a. If a test is now failing because of xpass, because some previous errors -# are now fixed, removed the corresponding xfail. -# 2b. If a test is not failing consistently, use skip. -# NOTE: EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES only supports dtypes. If a matcher or model_type -# is needed, use the SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE list further down below. -EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] = ( - xfail( - "__getitem__", - reason="io_adaper doesn't support __getitem__ input slice(0, 3, None)", - ), - xfail( - "__radd__", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "bool"), - ), - xfail( - "__rmatmul__", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "__rpow__", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Pow", "int"), - ), - skip( - "_native_batch_norm_legit", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - skip( - "_batch_norm_with_update", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch and type error", - ), - xfail( - "_softmax_backward_data", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "_unsafe_masked_index", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("_unsafe_masked_index", "complex64"), - ), - xfail( - "_unsafe_masked_index_put_accumulate", - reason="fixme: Status Message: updates tensor should have shape equal to " - "indices.shape[:-1] + data.shape[indices.shape[-1]:]", - ), - xfail( - "_unsafe_masked_index_put_accumulate", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "add", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Add") - ), - xfail( - "add", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Add", "int8, int16, uint8 have type issue." - ), - ), - xfail( - "addbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addbmm", "complex64") - ), - xfail( - "addmm", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - xfail( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Addmm") - ), - skip( - "addmm", dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - skip( - "addmm", - variant_name="decomposed", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addmm", "complex64 (core dump)") - ), - xfail( - "addr", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "Addr", "bool" - ), - ), - xfail( - "addr", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Addr", "complex64") - ), - xfail( - "alias_copy", - dtypes=(torch.int8, torch.uint8, torch.int16, torch.float64), - reason="OnnxExporterError: Failed to export model", - ), - xfail( - "allclose", - reason=onnx_test_common.reason_dynamo_does_not_support("Allclose") - ), - xfail( - "amax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_dynamo_does_not_support("ReduceMin", "bool, int16") - ), - xfail( - "aminmax", - dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES), - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16"), - ), - xfail( - "arange", - dtypes=(torch.uint8,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8, int8"), - ), - xfail( - "arange", - dtypes=(torch.int16, torch.int32), - reason="AssertionError: The values for attribute 'shape' do not match", - ), - xfail( - "argmax", - dtypes=( - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMax", "int16, int64" - ), - ), - xfail( - "argmin", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - torch.int64, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ArgMin", "uint8, int8, int16, int64" - ), - ), - xfail( - "argwhere", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "as_strided", - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults", - ), - xfail( - "atan2", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "baddbmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "baddbmm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("baddbmm", "complex64") - ), - xfail( - "bernoulli", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "bfloat16", - reason="fixme: ORT errors with RuntimeError: No corresponding Numpy type for Tensor Type.", - ), - xfail( - "bincount", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.bincount.default"), - ), - xfail( - "block_diag", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Block_diag", "complex"), - ), - xfail( - "bmm", - dtypes=( - torch.uint8, - torch.int8, - torch.int16, - ), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Matmul", "uint8, int8, int16" - ), - ), - xfail( - "broadcast_shapes", - reason=onnx_test_common.reason_dynamo_does_not_support("output is int"), - ), - xfail( - "cauchy", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - skip( - "ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int") - ), - xfail( - "chalf", - reason="fixme: ONNX shape type inference error: Invalid tensor data type 0." - ), - xfail( - "chunk", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Chunk", "uint8, int8, int16" - ), - ), - xfail( - "clamp", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_max", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool") - ), - xfail( - "clamp_max", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Max", "uint8, int8, int16" - ), - ), - xfail( - "clamp_min", dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool") - ), - xfail( - "constant_pad_nd", - dtypes=(torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Constant_pad_nd", "int16" - ), - ), - xfail( - "constant_pad_nd", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support( - "Constant_pad_nd", "complex64" - ), - ), - xfail( - "corrcoef", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cov", - reason=onnx_test_common.reason_dynamo_does_not_support( - "aten.equal.default" - ), - ), - xfail( - "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16") - ), - xfail( - "combinations", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked.select"), - ), - xfail( - "diag", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "diagonal_copy", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Diagonal", "bool"), - ), - xfail( - "dot", dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16") - ), - skip( - "dot", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("Dot", "complex64(core dump)"), - ), - xfail( - "empty", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "empty_strided", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "eq", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"), - ), - xfail( - "equal", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "exponential", - reason=onnx_test_common.reason_dynamo_does_not_support("exponential"), - ), - xfail( - "fft.fft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.fftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.ifftn", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfft2", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "fft.irfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfftn", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "fft.rfft2", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "floor", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "floor_divide", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"), - ), - xfail( - "full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full", "complex64") - ), - xfail( - "full_like", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("full_like", "complex64") - ), - xfail( - "gather", - reason="GatherElements op: Rank of input 'data' needs to be equal to rank of input 'indices'" - ), - xfail( - "geometric", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "heaviside", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Heaviside", "bool"), - ), - xfail( - "index_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "int64, int32, bool"), - ), - xfail( - "index_fill", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("index_fill", "complex64") - ), - xfail( - "index_fill", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Constant input list has None. ONNXScript does not support None in constant list." - ), - xfail( - "index_put", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"), - ), - xfail( - "index_put", - dtypes=(torch.uint8, torch.int8, torch.int16,), - reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"), - ), - xfail( - "index_put", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterND", "float16"), - ), - xfail( - "isnan", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "istft", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "item", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "lerp", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("lerp", "complex64") - ), - xfail( - "linalg.lstsq", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.lstsq", - variant_name="grad_oriented", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.linalg_lstsq.default"), - ), - xfail( - "linalg.matrix_power", - reason="fixme: The values for attribute 'shape' do not match: torch.Size([2, 2]) != torch.Size([2, 2, 2])." - ), - xfail( - "linalg.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.norm", - variant_name="subgradients_at_zero", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "linalg.vecdot", - reason="fixme: Assertion error: result shape mismatch", - ), - xfail( - "linspace", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32,), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ), - xfail( - "linspace", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "linspace", - variant_name="tensor_overload", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("linspace", "complex64") - ), - xfail( - "log_normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "log_softmax", - variant_name="with_dtype", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "logical_and", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("And", "float, int"), - ), - xfail( - "logical_not", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Not", "float, int"), - ), - xfail( - "logical_or", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Or", "float, int"), - ), - xfail( - "logical_xor", - dtypes=onnx_test_common.FLOAT_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Xor", "float, int"), - ), - skip( - "masked.logsumexp", - reason="fixme: https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "masked.amax", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.amin", - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.argmin", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.argmax", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.FLOAT_TYPES + (torch.int64,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked_fill", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.sum", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked.log_softmax", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "masked.mean", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("ReduceMean", "bool"), - ), - xfail( - "masked.norm", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "masked.prod", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "masked_select", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.masked_select.default"), - ), - xfail( - "max", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMax", "bool"), - ), - xfail( - "max", - variant_name="reduction_with_dim", - dtypes=(torch.int64,), - reason="https://github.com/onnx/onnx/issues/4986", - ), - xfail( - "min", - variant_name="reduction_no_dim", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - xfail( - "min", - variant_name="reduction_with_dim", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ReduceMin", "bool"), - ), - skip( - "mm", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("MM", "complex64(core dump)"), - ), - xfail( - "multinomial", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nanquantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "nansum", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("IsNaN", "int, bool"), - ), - xfail( - "narrow", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - skip( - "native_batch_norm", - reason=onnx_test_common.reason_onnx_script_does_not_support("cpu is not supported: \ - https://github.com/microsoft/onnxscript/pull/1289") - ), - xfail( - "native_layer_norm", - dtypes=(torch.float16,), - reason="fixme: ORT optimizer error: https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "new_full", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason=onnx_test_common.reason_dynamo_does_not_support("new_full", "complex64") - ), - xfail( - "nn.functional.adaptive_avg_pool2d", - reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \ - maximum recursion depth exceeded while calling a Python object"), - ), - xfail( - "nn.functional.adaptive_avg_pool3d", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"), - ), - xfail( - "nn.functional.alpha_dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.avg_pool1d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool2d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.avg_pool3d", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"), - ), - xfail( - "nn.functional.batch_norm", - dtypes=(torch.float16,), - reason="fixme: https://github.com/microsoft/onnxscript/issues/1270", - ), - xfail( - "nn.functional.conv_transpose1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv_transpose2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv_transpose3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - skip( - "nn.functional.conv_transpose1d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose2d", - reason="fixme: Assertion error: result mismatch", - ), - skip( - "nn.functional.conv_transpose3d", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv1d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"), - ), - xfail( - "nn.functional.conv2d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.conv3d", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_does_not_support("Conv3d", "int64"), - ), - xfail( - "nn.functional.conv3d", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.cosine_embedding_loss", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("CosineEmbeddingLoss", "bool"), - ), - xfail( - "nn.functional.ctc_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.ctc_loss.default"), - ), - xfail( - "nn.functional.dropout", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.dropout3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="with_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.feature_alpha_dropout", - variant_name="without_train", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool2d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.fractional_max_pool3d", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.gaussian_nll_loss", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.gaussian_nll_loss"), - ), - xfail( - "nn.functional.grid_sample", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.group_norm", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("GroupNormalization", "float16"), - ), - xfail( - "nn.functional.local_response_norm", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("avgpool", "int64"), - ), - xfail( - "nn.functional.linear", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Gemm", "int"), - ), - xfail( - "nn.functional.max_pool2d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"), - ), - xfail( - "nn.functional.max_pool3d", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"), - ), - xfail( - "nn.functional.multi_head_attention_forward", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.one_hot", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "nn.functional.pad", - variant_name="replicate", - reason="fixme: ORT error: padding size", - ), - xfail( - "nn.functional.pad", - variant_name="replicate_negative", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pad", - variant_name="reflect", - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nn.functional.pixel_shuffle", - dtypes=(torch.int32, torch.int64) + onnx_test_common.BOOL_TYPES, - reason="fixme: ONNX Runtime does not support int32/64 inputs", - ), - xfail( - "nn.functional.pixel_unshuffle", - reason=onnx_test_common.reason_onnx_script_does_not_support("aten.pixel_unshuffle.default"), - ), - xfail( - "nn.functional.poisson_nll_loss", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason="fixme: result mismatch with NaN.", - ), - xfail( - "nn.functional.rrelu", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "nn.functional.rrelu", - dtypes=(torch.int64,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Relu", "int64"), - ), - skip( - "nn.functional.scaled_dot_product_attention", - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ), - xfail( - "nn.functional.scaled_dot_product_attention", - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - ), - xfail( - "nn.functional.selu", - reason="fixme: nn.functional.selu is not in torch._decomp.decomposition_table", - ), - xfail( - "nn.functional.soft_margin_loss", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "nonzero", - dtypes=(torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"), - ), - xfail( - "normal", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="in_place", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "normal", - variant_name="number_mean", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "ones", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - xfail( - "pca_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "quantile", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default") - ), - xfail( - "rand_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randint_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "randn_like", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "resize_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "resize_as_", - reason=onnx_test_common.reason_dynamo_does_not_support("resize_as_") - ), - xfail( - "round", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Round", "int"), - ), - xfail( - "rsub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "scatter_add", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="sum", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="prod", - dtypes=(torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amin", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="amax", - dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"), - ), - xfail( - "scatter_reduce", - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ), - xfail( - "sgn", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "sign", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Sign", "bool"), - ), - xfail( - "signal.windows.kaiser", - reason=onnx_test_common.reason_dynamo_does_not_support("functionalization"), - ), - xfail( - "softmax", - dtypes=(torch.float16,), - reason="ORT error: https://github.com/microsoft/onnxruntime/issues/16438" - ), - xfail( - "sparse.mm", - variant_name="reduce", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "sparse.sampled_addmm", - reason=onnx_test_common.reason_dynamo_does_not_support("InternalTorchDynamoError: Sparse CSR tensors do not have strides"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Erf", "int, bool"), - ), - xfail( - "special.erfcx", - dtypes=onnx_test_common.FLOAT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Erfcx"), - ), - xfail( - "special.log_ndtr", - dtypes=onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES, - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "special.ndtr", - dtypes=(torch.float16,), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "square", - dtypes=(torch.int8, torch.uint8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"), - ), - xfail( - "squeeze", - variant_name="multiple", - reason="fixme: https://github.com/microsoft/onnxscript/issues/1264", - ), - xfail( - "svd_lowrank", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "stft", - reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"), - ), - xfail( - "sub", - dtypes=(torch.uint8, torch.int8, torch.int16), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Mul", "uint8, int8, int16" - ), - ), - xfail( - "take", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "tensor_split", - reason=onnx_test_common.reason_dynamo_does_not_support("data-dependent"), - ), - xfail( - "topk", - dtypes=(torch.int64, torch.int32, torch.float16), - reason="fixme: Assertion error: result mismatch", - ), - xfail( - "tril", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "triu", - dtypes=onnx_test_common.BOOL_TYPES + (torch.int32,), - reason=onnx_test_common.reason_onnx_runtime_does_not_support("trilu", "bool, int32"), - ), - xfail( - "trunc", - dtypes=onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Floor", "int"), - ), - xfail( - "unflatten", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_does_not_support("Unflatten") - ), - xfail( - "uniform", - reason=onnx_test_common.reason_dynamo_does_not_support("wrapper_set_seed"), - ), - xfail( - "unique", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unique_consecutive", - reason=onnx_test_common.reason_dynamo_does_not_support("aten.unique_consecutive.default"), - ), - xfail( - "unravel_index", - dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES, - reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"), - ), - xfail( - "unsqueeze_copy", - reason="OnnxExporterError: Failed to export model", - dtypes=(torch.int8, torch.uint8, torch.int16), - ), - xfail( - "where", - dtypes=onnx_test_common.BOOL_TYPES, - reason=onnx_test_common.reason_onnx_runtime_does_not_support("Where", "bool"), - ), - xfail( - "zeros", - dtypes=onnx_test_common.COMPLEX_TYPES, - reason="fixme: kwargs dtpye=complex64 is not supported in ONNX." - ), - # SLOW TESTS (All are xfails if we run them) - # TODO: https://github.com/pytorch/pytorch/issues/117118 - skip_slow( - "cdist", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogram", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "histogramdd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.lu_solve", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.solve_triangular", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "linalg.svd", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "logspace", - variant_name="tensor_overload", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "max_pool2d_with_indices_backward", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.interpolate", - variant_name="bicubic", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_unpool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool1d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool2d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.max_pool3d", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "nn.functional.unfold", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "ormqr", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "searchsorted", - reason="fixme: Test sets are too many.", - ), - skip_slow( - "svd", - reason="fixme: Test sets are too many.", - ), -) -# fmt: on - -# NOTE: The xfail and skip with a matcher function or model_type should be -# at under the `SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE` section. -SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: tuple[ - onnx_test_common.DecorateMeta, ... -] = ( - skip( - "_native_batch_norm_legit", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - # TODO: This test currently fails only for certain inputs, e.g. shape([3, 1]). - # Numerically the ONNX program is correct, but the output shapes for `save_mean` - # and `save_var` were tensor(-2.1268) instead of the correct tensor([-2.1268]) - # for example. - skip( - "_batch_norm_with_update", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - reason="not supported yet", - ), - xfail( - "addmm", # xfail can't only use dtypes to catch all cases - matcher=lambda sample: sample.input.dtype - in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Gemm", "uint8, int8, int16, int32, int64" - ), - ), - xfail( - "addmm", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "addmm", - variant_name="decomposed", - matcher=lambda sample: sample.args[0].numel() == 0, - reason="ONNX Runtime does not support empty tensors multiplication", - ), - xfail( - "amax", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "amin", - matcher=lambda sample: len(sample.input.shape) == 0 - and (sample.kwargs.get("dim") is not None and sample.kwargs.get("dim") != ()), - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - xfail( - "aminmax", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("dim") is not None, - reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0", - ), - skip( - "cat", - matcher=lambda sample: sample.input[0].equal(torch.tensor([])), - reason="core dump - cat does not support zero-dim tensors yet", - ), - xfail( - "index_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_add", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_copy", - matcher=lambda sample: len(sample.input.shape) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "ScatterND", "0-D tensor" - ), - ), - xfail( - "index_copy", - matcher=lambda sample: isinstance(sample.args[0], int) and sample.args[0] == -1, - reason="fixme: aten::index_put indices contains None when dim is -1", - ), - xfail( - "index_put", - matcher=lambda sample: (sample.args[0][0].dtype == torch.bool) - and (sample.kwargs.get("accumulate") is False), - reason=onnx_test_common.reason_dynamo_does_not_support( - "https://github.com/pytorch/pytorch/issues/101150" - ), - ), - skip( - "linalg.multi_dot", - matcher=lambda sample: sum(torch.numel(input) for input in sample.input) == 0, - reason="fixme: Undefined", - ), - skip( - "log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "log_softmax", - variant_name="with_dtype", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "masked.log_softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - skip( - "matmul", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - skip( - "mm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason="values of matmul of [m, 0] and [0, n] matrices are undefined", - ), - xfail( - "native_batch_norm", - matcher=lambda sample: sample.args[-3] is True - and any(arg is not None for arg in sample.args[2:4]), - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.avg_pool1d", - matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) - and ( - sample.kwargs.get("count_include_pad") is True - or sample.input.shape[2] - % ( - sample.args[0][0] - if isinstance(sample.args[0], tuple) - else sample.args[0] - ) - != 0 - ), - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool2d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, - reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", - ), - xfail( - "nn.functional.avg_pool3d", - matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) - or (sample.kwargs.get("divisor_override") is not None), - reason="ONNX doesn't support divisor_override argument", - ), - xfail( - "nn.functional.batch_norm", - matcher=lambda sample: sample.kwargs.get("training") is True - and any(arg is not None for arg in sample.args[2:4]), - reason="Flaky failure: https://github.com/pytorch/pytorch/issues/115106", - ), - xfail( - "nn.functional.conv2d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - xfail( - "nn.functional.conv3d", - matcher=lambda sample: sample.kwargs.get("padding") == "valid", - reason="fixme: https://github.com/pytorch/pytorch/issues/117054", - ), - skip( - "nn.functional.cross_entropy", - matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int), - reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[weight] is int type", - ), - xfail( - "nn.functional.embedding", - matcher=lambda sample: sample.kwargs.get("max_norm") is not None, - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="https://github.com/pytorch/pytorch/issues/115106", - ), - skip_torchlib_forward_compatibility( - "nn.functional.embedding_bag", - matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True, - reason=onnx_test_common.reason_onnx_script_does_not_support( - "'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. " - "'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided" - ), - github_issue="https://github.com/microsoft/onnxscript/issues/1056", - ), - xfail( - "nn.functional.group_norm", - matcher=lambda sample: torch.numel(sample.input) == 0, - reason=onnx_test_common.reason_onnx_runtime_does_not_support( - "Reshape", "empty tensor" - ), - ), - xfail( - "nn.functional.instance_norm", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - matcher=lambda sample: sample.kwargs.get("running_mean") is not None, - reason="fixme: KeyError: 'self___kwargs__running_mean'", - ), - xfail( - "nn.functional.max_pool3d", - matcher=lambda sample: sample.kwargs.get("ceil_mode") is True - and sample.kwargs.get("padding") == 1, - reason="FIXME: After https://github.com/microsoft/onnxruntime/issues/15446 is fixed", - ), - xfail( - "nn.functional.pixel_shuffle", - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", - ), - xfail( - "nonzero", - matcher=lambda sample: len(sample.input.shape) == 0 - and sample.kwargs.get("as_tuple", False) is False, - reason="Output 'shape' do not match: torch.Size([0, 1]) != torch.Size([0, 0]).", - model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - ), - xfail( - "scatter_add", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch", - ), - skip( - "scatter_reduce", - variant_name="amax", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="amin", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="prod", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "scatter_reduce", - variant_name="sum", - # ONNX has not include_self parameter and default is include_self=True mode - matcher=lambda sample: sample.kwargs.get("include_self") is False, - reason="ONNX does't support include_self=False option", - ), - skip( - "softmax", - matcher=lambda sample: len(sample.input.shape) == 0, - reason="fixme: LogSoftMax does not support empty tensor as input", - ), - xfail( - "unflatten", - reason="Logic not implemented for size 0 inputs in op.Reshape", - matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), - ), - skip( - "signal.windows.hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_hamming", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.blackman", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.general_cosine", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.hann", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), - skip( - "signal.windows.nuttall", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - reason="does not match node name", - ), -) - -OPS_DB = copy.deepcopy(common_methods_invocations.op_db) -OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset( - meta.op_name for meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE -) -ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB) - - -def _torch_size_flatten_spec(d: List[Any], spec: Any) -> List[Any]: - return [d[i] for i in range(spec.num_children)] - - -torch.fx._pytree.register_pytree_flatten_spec( - torch.Size, - _torch_size_flatten_spec, -) - - -class SingleOpModel(torch.nn.Module): - """Test model to wrap around a single op for export.""" - - def __init__(self, op, kwargs): - super().__init__() - self.operator = op - self.kwargs = kwargs - - def forward(self, *args): - return self.operator(*args, **self.kwargs) - - -def _should_skip_xfail_test_sample( - op_name: str, - variant_test_name: str, - sample, - model_type: pytorch_test_common.TorchModelType, -) -> Tuple[Optional[str], Optional[str]]: - """Check if the test sample should be skipped or xfailed. - - If the xfail/skip decorator meta is matched with its op_name and model_type, - return the test_behavior and reason. Otherwise, return None, None. Note that - if the matcher is None, the test is decorator_meta is meant to skip/xfail all model types. - - Args: - op_name: The name of the op. - sample: The test sample. - model_type: The model type of the test. - - Returns: - A tuple of (test_behavior, reason). test_behavior is either "skip" or "xfail". - reason is the reason for the test_behavior. - """ - - if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS: - return None, None - for decorator_meta in SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE: - # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS_WITH_MATCHER_AND_MODEL_TYPE. That's fine because the list is small. - # NOTE: If model_type is None, the test is decorator_meta is meant to skip/xfail all model types. - if ( - decorator_meta.op_name == op_name - and decorator_meta.variant_name == variant_test_name - ) and ( - model_type == decorator_meta.model_type or decorator_meta.model_type is None - ): - if decorator_meta.matcher is None and decorator_meta.model_type is None: - raise TypeError( - "Either Matcher or model_type must be defined in sub xfail and skip." - ) - if decorator_meta.matcher is not None and decorator_meta.matcher(sample): - return decorator_meta.test_behavior, decorator_meta.reason - elif decorator_meta.matcher is None: - # xfail/skip the whole test of the model type without matcher - return decorator_meta.test_behavior, decorator_meta.reason - return None, None - - -def _compare_onnx_and_torch_exported_program( - torch_exported_program, - onnx_exported_program, - input_args, - input_kwargs=None, - test_name=None, - sample_num=None, - sample_kwargs=None, - rtol=1e-03, - atol=1e-07, - only_check_shape=False, -): - # avoid mutable default argument - if input_kwargs is None: - input_kwargs = {} - - # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict. - # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict. - # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__() - onnx_outputs = onnx_exported_program(*input_args, **input_kwargs) - if isinstance(torch_exported_program, torch.export.ExportedProgram): - torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) - else: - torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): - raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" - ) - - for j, (torch_output, onnx_output) in enumerate( - zip(torch_outputs_onnx_format, onnx_outputs) - ): - if only_check_shape: - assert torch_output.shape == onnx_output.shape - else: - try: - torch.testing.assert_close( - torch.tensor(onnx_output), - torch_output, - rtol=rtol, - atol=atol, - equal_nan=True, - ) - except AssertionError as e: - if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1": - error_reproduction.create_mismatch_report( - test_name, - sample_num, - onnx_exported_program.model_proto, - input_args, - sample_kwargs, - torch.tensor(onnx_output), - torch_output, - e, - ) - if len(torch_outputs_onnx_format) > 1: - raise AssertionError(f"Output {j} mismatch") from e - raise - - -def _run_test_output_match( - test_suite: onnx_test_common._TestONNXRuntime, - device: str, - dtype: torch.dtype, - op: opinfo_core.OpInfo, -): - # device is provided by instantiate_device_type_tests, but we only want to run in cpu. - assert device == "cpu" - samples = op.sample_inputs( - device, - dtype, - requires_grad=False, - ) - for i, cpu_sample in enumerate(samples): - inputs = (cpu_sample.input, *cpu_sample.args) - # Provide the repr to subtest because tensors are not serializable in parallel test runs - - with test_suite.subTest( - opset=test_suite.opset_version, - sample_num=i, - inputs=repr(inputs), - kwargs=repr(cpu_sample.kwargs), - ): - test_behavior, reason = _should_skip_xfail_test_sample( - op.name, op.variant_test_name, cpu_sample, test_suite.model_type - ) - with onnx_test_common.normal_xfail_skip_test_behaviors( - test_behavior, reason - ): - model = SingleOpModel(op.op, cpu_sample.kwargs) - model.eval() - - if ( - dtype == torch.float32 - and op.name in test_suite.fp32_low_precision_dict - ): - rtol = test_suite.fp32_low_precision_dict[op.name][0] - atol = test_suite.fp32_low_precision_dict[op.name][1] - elif dtype == torch.float32: - # Relax atol and rtol for float32 based on empirical results - rtol = 1e-5 - atol = 2e-5 - elif ( - dtype == torch.float16 - and (op.name, op.variant_test_name) - in test_suite.fp16_low_precision_variant_dict - ): - rtol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][0] - atol = test_suite.fp16_low_precision_variant_dict[ - (op.name, op.variant_test_name) - ][1] - elif ( - dtype == torch.float16 - and op.name in test_suite.fp16_low_precision_dict - ): - rtol = test_suite.fp16_low_precision_dict[op.name][0] - atol = test_suite.fp16_low_precision_dict[op.name][1] - else: - rtol = None - atol = None - - if ( - test_suite.model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - try: - # TODO (tugsbayasgalan) Migrate to pre-dispatch IR - # BUG1: python test/onnx/test_fx_op_consistency.py -k test_output_match_triu_cpu_int32 - # has unexpected success, but don't know how to remove from xfail list - # BUG2: User output to_sparse is not in the correct order or is not found in the - # exported program's user_output list (https://github.com/pytorch/pytorch/issues/124328) - # python test/onnx/test_fx_op_consistency.py -k test_output_match_to_sparse_cpu_float32 - # BUG3: [ShapeInferenceError] Inference error(s): (op_type:aten_view, node name: aten_view_4): - # [ShapeInferenceError] - # Inference error(s): (op_type:Reshape, node name: n1): [ShapeInferenceError] Invalid position of 0. - # python test/onnx/test_fx_op_consistency.py -k test_output_match_stack_cpu_int32 - from torch.export import _trace - - model = _trace._export(model, inputs, pre_dispatch=False) - - except AssertionError as e: - # NOTE: avoid fake_mode detection bug in torch.export.export - pytest.xfail( - onnx_test_common.reason_dynamo_does_not_support(str(e)) - ) - - try: - onnx_program = torch.onnx.dynamo_export( - model, - *inputs, - ) - except torch.onnx.OnnxExporterError as e: - # NOTE: If the model has unsupported nodes, we will skip the test - # with non-strict xfail. Otherwise, we will raise the error. - if hasattr( - e.__cause__, "diagnostic" - ) and e.__cause__.diagnostic.rule in ( - _rules._POERules.no_symbolic_function_for_call_function, - _rules._POERules.unsupported_fx_node_analysis, - ): - pytest.xfail( - onnx_test_common.reason_onnx_script_does_not_support(str(e)) - ) - else: - raise e - _compare_onnx_and_torch_exported_program( - model, - onnx_program, - inputs, - test_name=test_suite.id(), - sample_num=i, - sample_kwargs=cpu_sample.kwargs, - rtol=rtol, - atol=atol, - only_check_shape=(op.name in test_suite.only_shape_check_list), - ) - - -def _parameterized_class_attrs_and_values(): - input_values = [] - input_values.extend( - itertools.product( - (opset for opset in onnx_test_common.FX_TESTED_OPSETS), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), - ) - ) - return { - "attrs": ["opset_version", "model_type"], - "input_values": input_values, - } - - -def _parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]): - """Combine class name with the parameterized arguments. - - This function is passed to `parameterized.parameterized_class` as the - `class_name_func` argument. - """ - suffixes = [] - for k, v in input_dicts.items(): - suffixes.append(f"{k}_{v}") - return f"{cls.__name__}_{'_'.join(suffixes)}" - - -@parameterized.parameterized_class( - **_parameterized_class_attrs_and_values(), - class_name_func=_parameterize_class_name, -) -class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime): - """Test output consistency between exported ONNX models and PyTorch eager mode. - - This is a parameterized test suite. - """ - - opset_version = -1 - dynamic_shapes: bool = False - model_type: pytorch_test_common.TorchModelType = ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE - ) - - # NOTE: Follow torchlib settings in ops_test_data.py - only_shape_check_list = [ - "empty", - "empty_like", - "empty_strided", - "new_empty", - "new_empty_strided", - ] - - fp32_low_precision_dict = { - "native_layer_norm": [2e-4, 7e-4], - } - - fp16_low_precision_dict = { - "addbmm": [2e-1, 2e-2], - "addcdiv": [3e-2, 1.4e-3], - "addcmul": [3e-2, 1e-3], - "addmv": [5e-2, 3e-2], - "addr": [3e-3, 4e-3], - "baddbmm": [3e-2, 1e-3], - "cumulative_trapezoid": [3e-2, 1e-3], - "cross": [3e-2, 2e-2], - "diff": [1e-2, 5e-2], - "div": [5e-3, 1e-3], - "gradient": [3e-3, 4e-3], - "linalg.cross": [1e-3, 2e-2], - "linalg.multi_dot": [3e-2, 1e-3], - "linalg.vecdot": [1e-2, 2e-2], - "linspace": [2e-2, 2e-3], - "masked.std": [2e-2, 2e-3], - "masked.var": [2e-2, 2e-2], - "matmul": [2e-2, 6e-2], - "mv": [9e-3, 1e-5], - "nn.functional.batch_norm": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy": [3e-2, 1e-3], - "nn.functional.binary_cross_entropy_with_logits": [4e-2, 4e-3], - "nn.functional.cosine_similarity": [3e-2, 1e-3], - "nn.functional.cosine_embedding_loss": [1e-2, 1e-3], - "nn.functional.hardsigmoid": [1e-3, 5e-3], - "nn.functional.hardswish": [1e-3, 5e-3], - "nn.functional.hinge_embedding_loss": [4e-1, 3e-3], - "nn.functional.huber_loss": [1e-2, 1e-1], - "nn.functional.instance_norm": [1e-2, 1e-3], - "nn.functional.interpolate": [1e-2, 1e-3], - "nn.functional.kl_div": [2e-3, 2e-4], - "nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3], - "nn.functional.local_response_norm": [1e-2, 5e-3], - "nn.functional.poisson_nll_loss": [4e-2, 6e-3], - "nn.functional.nll_loss": [3e-2, 1e-3], - "nn.functional.triplet_margin_loss": [2e-2, 1e-2], - "nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2], - "native_batch_norm": [3e-2, 1e-3], - "norm": [1e-2, 1e-2], - "dot": [3e-2, 1e-3], - "logit": [3e-2, 1e-3], - "rsub": [3e-2, 1e-3], - "sinc": [2e-1, 6e-4], - "sub": [3e-2, 1e-3], - "trapezoid": [1e-3, 7e-3], - "trapz": [1e-3, 7e-3], - "vdot": [1e-3, 1e-2], - } - - fp16_low_precision_variant_dict = { - ("nn.functional.interpolate", "trilinear"): [3e-2, 3e-3], - ("nn.functional.interpolate", "linear"): [3e-2, 3e-3], - } - - @common_device_type.ops( - [op for op in OPS_DB if op.name in ALL_OPS_IN_DB], - allowed_dtypes=onnx_test_common.TESTED_DTYPES, - ) - def test_output_match(self, device: str, dtype: torch.dtype, op): - """Test the ONNX exporter.""" - _run_test_output_match(self, device, dtype, op) - - -for opset in onnx_test_common.FX_TESTED_OPSETS: - for model_type in pytorch_test_common.TorchModelType: - # The name needs to match the parameterized_class name. - test_class_name = f"TestOnnxModelOutputConsistency_opset_version_{opset}_model_type_TorchModelType.{model_type.name}" - onnx_test_common.add_decorate_info( - OPS_DB, - test_class_name, - "test_output_match", - opset=opset, - skip_or_xfails=EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES, - ) - - common_device_type.instantiate_device_type_tests( - globals()[test_class_name], globals(), only_for="cpu" - ) - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 747c922a549542..47778a689d8330 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -192,9 +192,7 @@ def forward(self, input): return self.conv2(input) x = torch.randn(20, 16, 50, 50) - onnx_program = dynamo_export( - TraceModel(), x, export_options=ExportOptions(op_level_debug=False) - ) + onnx_program = dynamo_export(TraceModel(), x) assert_has_diagnostics( onnx_program.diagnostic_context, diagnostics.rules.find_opschema_matched_symbolic_function, @@ -202,34 +200,6 @@ def forward(self, input): expected_node="aten.convolution.default", ) - def test_dispatch_overload_fall_back_default_raise_diagnostic_warning(self): - class TraceModel(torch.nn.Module): - def forward(self, input): - return torch.ops.aten.add.Tensor(input, input) - - onnx_registry = torch.onnx.OnnxRegistry() - self.assertTrue( - onnx_registry.is_registered_op( - namespace="aten", op_name="add", overload="Tensor" - ) - ) - - aten_add_Tensor = registration.OpName.from_name_parts( - namespace="aten", op_name="add", overload="Tensor" - ) - onnx_registry._registry.pop(aten_add_Tensor) - - x = torch.tensor(3) - onnx_program = dynamo_export( - TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry) - ) - assert_has_diagnostics( - onnx_program.diagnostic_context, - diagnostics.rules.find_operator_overloads_in_onnx_registry, - diagnostics.levels.WARNING, - expected_node="aten.add.Tensor", - ) - def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self): class CustomModule(torch.nn.Module): def forward(self, input): @@ -272,20 +242,6 @@ def forward(self, input): export_options=torch.onnx.ExportOptions(onnx_registry=registry), ) - try: - torch.onnx.dynamo_export( - TraceModel(), - x, - export_options=torch.onnx.ExportOptions(onnx_registry=registry), - ) - except torch.onnx.OnnxExporterError as e: - assert_has_diagnostics( - e.onnx_program.diagnostic_context, - diagnostics.rules.no_symbolic_function_for_call_function, - diagnostics.levels.ERROR, - expected_node="aten.mul.Tensor", - ) - def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info( self, ): @@ -574,34 +530,6 @@ def test_fake_tensor_mode_huggingface_tiiuae_falcon(self): onnx.checker.check_model(onnx_program.model_proto) onnx.shape_inference.infer_shapes(onnx_program.model_proto) - def test_exported_program_input_with_custom_fx_tracer(self): - from torch.onnx._internal import _exporter_legacy - from torch.onnx._internal.fx import dynamo_graph_extractor - - class Model(torch.nn.Module): - def forward(self, x): - return x + 1 - - x = torch.randn(1, 1, 2) - exported_program = torch.export.export(Model(), args=(x,)) - - export_options = torch.onnx.ExportOptions() - export_options = _exporter_legacy.ResolvedExportOptions( - export_options, model=exported_program - ) - export_options.fx_tracer = ( - dynamo_graph_extractor.DynamoExport() - ) # Override fx_tracer to an unsupported tracer - with self.assertRaises(torch.onnx.OnnxExporterError): - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - export_options=export_options, - ) - self.assertTrue(onnx_program._export_exception is not None) - with self.assertRaises(torch.onnx.InvalidExportOptionsError): - raise self._export_exception - def test_exported_program_torch_distributions_normal_Normal(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -636,21 +564,6 @@ def forward(self, input): # with no Cast node in between. self.assertEqual(div_node.input[0], model_proto.graph.input[0].name) - def test_exported_program_as_input_with_model_signature(self): - class Model(torch.nn.Module): - def forward(self, x): - return x + 1.0 - - x = torch.randn(1, 1, 2, dtype=torch.float) - exported_program = torch.export.export(Model(), args=(x,)) - - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - ) - - self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature) - @common_utils.parametrize( "float8_type", [ @@ -818,11 +731,11 @@ def forward(self, tensor_x: torch.Tensor): onnx_program = torch.onnx.dynamo_export( model, tensor_x, export_options=export_options ) + onnx_program.apply_weights(state_dict) with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file: onnx_program.save( tmp_onnx_file.name, include_initializers=include_initializer, - model_state=state_dict if include_initializer else None, ) onnx_model = onnx.load(tmp_onnx_file.name) self.assertEqual( diff --git a/test/onnx/test_fx_to_onnx_decomp_skip.py b/test/onnx/test_fx_to_onnx_decomp_skip.py index db8edce1425949..466ee4a0bb956f 100644 --- a/test/onnx/test_fx_to_onnx_decomp_skip.py +++ b/test/onnx/test_fx_to_onnx_decomp_skip.py @@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str): class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): - def _test_exported_program_forces_decomposition(self, model, input, op_type): - ep = torch.export.export(model, input) - onnx_program = torch.onnx.dynamo_export(ep, *input) - with self.assertRaises(AssertionError): - assert_op_in_onnx_model(onnx_program.model_proto, op_type) - def test_upsample_bilinear2d(self): class TestModel(torch.nn.Module): def __init__(self) -> None: @@ -37,9 +31,6 @@ def forward(self, x): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "Resize" - ) def test_upsample_bilinear2d_output_size(self): def func(x: torch.Tensor): @@ -61,9 +52,6 @@ def forward(self, x): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize" - ) def test_upsample_trilinear3d_output_size(self): def func(x: torch.Tensor): @@ -82,9 +70,6 @@ def forward(self, x): # If decomposition is skipped, the model will contain an InstanceNormalization op # instead of BatchNormalization op w/ training=True. assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization" - ) if __name__ == "__main__": diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 150edfef3412d1..ff4d3a91bd1afa 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values(): input_values.extend( itertools.product( (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -478,6 +475,9 @@ def forward(self, x): MutationModel(), (torch.randn(12),), has_mutation=True ) + @unittest.skip( + "Fixme: arange in torchlib does not support dynamic start and end yet." + ) def test_arange(self): class ArangeModel(torch.nn.Module): def forward(self, input): @@ -909,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options(): (True, False), (True, False), (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -983,13 +980,6 @@ def _test_fake_tensor_mode_exporter( # Create the toy model with real weight. real_model = create_model() state_dict = real_model.state_dict() # concrete (non-fake) state_dict - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - real_model = torch.export.export( - real_model, args=create_args(), kwargs=create_kwargs() - ) with tempfile.NamedTemporaryFile( prefix=model_name, suffix=".pt" @@ -1012,13 +1002,6 @@ def _test_fake_tensor_mode_exporter( ) if export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, @@ -1027,13 +1010,6 @@ def _test_fake_tensor_mode_exporter( ) if not export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, **fake_kwargs, export_options=export_options ) @@ -1090,10 +1066,6 @@ def _test_fake_tensor_mode_exporter( for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close(ref_output, torch.tensor(ort_output)) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_simple(self): def create_model() -> nn.Module: class Model(torch.nn.Module): @@ -1123,10 +1095,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="!(it.GetName().empty())", reason="With after onnx==1.16, constant folding in optimizer causes this error.", @@ -1163,10 +1131,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_large_scale_exporter_with_toy_mlp(self): class MLPModel(nn.Module): def __init__(self) -> None: @@ -1205,10 +1169,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_google_t5(self): config = transformers.T5Config( vocab_size=8096, d_model=64, num_layers=2, num_heads=2 @@ -1241,10 +1201,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", @@ -1307,10 +1263,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_mosaicml_mpt(self): config = transformers.MptConfig( vocab_size=8096, d_model=64, n_heads=2, n_layers=3 @@ -1338,10 +1290,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, @@ -1371,10 +1319,6 @@ def create_model(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( error_message="Expected 5 inputs, got 3", reason="https://github.com/pytorch/pytorch/issues/115745", @@ -1414,10 +1358,6 @@ def create_kwargs(): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, diff --git a/test/onnx/test_fx_type_promotion.py b/test/onnx/test_fx_type_promotion.py index 1e3860ad2a8efa..fc7dc21fba0069 100644 --- a/test/onnx/test_fx_type_promotion.py +++ b/test/onnx/test_fx_type_promotion.py @@ -1,22 +1,16 @@ # Owner(s): ["module: onnx"] -import pytorch_test_common - from torch.onnx._internal.fx.passes import type_promotion from torch.testing._internal import common_utils class TestGeneratedTypePromotionRuleSet(common_utils.TestCase): - @pytorch_test_common.skip_in_ci( - "Reduce noise in CI. " - "The test serves as a tool to validate if the generated rule set is current. " - ) def test_generated_rule_set_is_up_to_date(self): generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET - latest_set = ( - type_promotion.TypePromotionRuleSetGenerator.generate_from_torch_refs() - ) + latest_set = type_promotion.ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs() + # Please update the list in torch/onnx/_internal/fx/passes/type_promotion.py following the instruction + # if this test fails self.assertEqual(generated_set, latest_set) def test_initialize_type_promotion_table_succeeds(self): diff --git a/test/onnx/test_lazy_import.py b/test/onnx/test_lazy_import.py new file mode 100644 index 00000000000000..bce4b892095ec8 --- /dev/null +++ b/test/onnx/test_lazy_import.py @@ -0,0 +1,37 @@ +# Owner(s): ["module: onnx"] + +import subprocess +import sys +import tempfile + +import pytorch_test_common + +from torch.testing._internal import common_utils + + +class TestLazyONNXPackages(pytorch_test_common.ExportTestCase): + def _test_package_is_lazily_imported(self, pkg, torch_pkg="torch.onnx"): + with tempfile.TemporaryDirectory() as wd: + r = subprocess.run( + [sys.executable, "-Ximporttime", "-c", "import torch.onnx"], + capture_output=True, + text=True, + cwd=wd, + check=True, + ) + + # The extra space makes sure we're checking the package, not any package containing its name. + self.assertTrue( + f" {pkg}" not in r.stderr, + f"`{pkg}` should not be imported, full importtime: {r.stderr}", + ) + + def test_onnxruntime_is_lazily_imported(self): + self._test_package_is_lazily_imported("onnxruntime") + + def test_onnxscript_is_lazily_imported(self): + self._test_package_is_lazily_imported("onnxscript") + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 64cbba6fc15fc5..37ca3836e53876 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -600,33 +600,6 @@ def forward(self, x, y): + ")." ) - def test_onnx_checker_invalid_graph(self): - class CustomAddModule(torch.nn.Module): - def forward(self, x, y): - return torch.add(x, y) - - def symbolic_custom_invalid_add(g, input, other, alpha=None): - return g.op("Add", input, other, invalid_attr_i=1) - - torch.onnx.register_custom_op_symbolic( - "::add", symbolic_custom_invalid_add, opset_version=9 - ) - - x = torch.randn(2, 3, 4) - y = torch.randn(2, 3, 4) - - test_model = CustomAddModule() - f = io.BytesIO() - - try: - with self.assertRaises(torch.onnx.errors.CheckerError): - torch.onnx.export(test_model, (x, y), f, opset_version=9) - finally: - torch.onnx.unregister_custom_op_symbolic("::add", 9) - - self.assertTrue(f.getvalue(), "ONNX graph was not exported.") - loaded_model = onnx.load_from_string(f.getvalue()) - def test_shape_value_map(self): class RSoftMax(torch.nn.Module): def __init__(self, radix, cardinality): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 5d2dc4bdf29f8a..c812b8e18b30c3 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -12534,6 +12534,27 @@ def forward(self, x, y): self.run_test(M(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(14) + def test_scaled_dot_product_attention(self): + class M(torch.nn.Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=1.0 + ) + + # Parameters + batch_size = 2 # Number of samples in the batch + num_heads = 4 # Number of attention heads + seq_length = 5 # Sequence length + head_dim = 8 # Dimensionality of each head + + # Create random query, key, and value tensors + q = torch.randn(batch_size, num_heads, seq_length, head_dim) + k = torch.randn(batch_size, num_heads, seq_length, head_dim) + v = torch.randn(batch_size, num_heads, seq_length, head_dim) + + self.run_test(M(), (q, k, v)) + @skipScriptTest() @skipIfUnsupportedMinOpsetVersion(11) def test_dist_normal(self): @@ -13582,7 +13603,7 @@ def forward(self, input, grid): ): if self.opset_version < 20: with self.assertRaises( - torch.onnx.errors.OnnxExporterError, + torch.onnx.OnnxExporterError, ): self.run_test( GridSampleModule(mode, padding_mode, align_corners), diff --git a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py index 7e3b24874e003b..7e7cf24a84e82a 100644 --- a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py +++ b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py @@ -36,14 +36,15 @@ def _compare_onnx_and_torch_exported_program( torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) else: torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): + + if isinstance(torch_outputs, torch.Tensor): + torch_outputs = [torch_outputs] + + if len(torch_outputs) != len(onnx_outputs): raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" + f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}" ) - for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs): + for torch_output, onnx_output in zip(torch_outputs, onnx_outputs): torch.testing.assert_close( torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol ) diff --git a/test/profiler/test_cpp_thread.cpp b/test/profiler/test_cpp_thread.cpp index 09d2fb42a2449b..ce60d9c816c697 100644 --- a/test/profiler/test_cpp_thread.cpp +++ b/test/profiler/test_cpp_thread.cpp @@ -1,5 +1,5 @@ -#include +#include // @manual #include #include diff --git a/test/profiler/test_cpp_thread.py b/test/profiler/test_cpp_thread.py index 109831b6634c2e..5dd12277e181be 100644 --- a/test/profiler/test_cpp_thread.py +++ b/test/profiler/test_cpp_thread.py @@ -26,7 +26,7 @@ def is_fbcode(): if is_fbcode(): - import caffe2.test.profiler_test_cpp_thread_lib as cpp + import caffe2.test.profiler_test_cpp_thread_lib as cpp # @manual=//caffe2/test:profiler_test_cpp_thread_lib else: # cpp extensions use relative paths. Those paths are relative to # this file, so we'll change the working directory temporarily diff --git a/test/dynamo_expected_failures/TestScript.test_is_after_use b/test/profiler/test_cpp_thread_lib.pyi similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_is_after_use rename to test/profiler/test_cpp_thread_lib.pyi diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index cf92a08d836e10..f4f4e2e99270a4 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -990,19 +990,41 @@ def trace_handler(p): # print(output) test_schedule = torch.profiler.schedule( - skip_first=2, wait=1, warmup=1, active=2, repeat=2 + skip_first=3, wait=2, warmup=1, active=4, repeat=2 ) test_schedule_expected_outputs = [ + # skip first 3 ProfilerAction.NONE, ProfilerAction.NONE, ProfilerAction.NONE, + # ---- + # repeat No. 1 begin + # wait 2 + ProfilerAction.NONE, + ProfilerAction.NONE, + # warmup 1 ProfilerAction.WARMUP, + # active 2 begin + ProfilerAction.RECORD, + ProfilerAction.RECORD, ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE, + # active 2 end + # repeat No. 1 end + # --- + # repeat No. 2 begin + # wait 2 + ProfilerAction.NONE, ProfilerAction.NONE, + # warmup 1 ProfilerAction.WARMUP, + # active 2 begin + ProfilerAction.RECORD, + ProfilerAction.RECORD, ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE, + # active 2 end + # repeat No. 2 end ProfilerAction.NONE, ProfilerAction.NONE, ProfilerAction.NONE, @@ -1644,7 +1666,12 @@ def test_profiler_op_event_kwargs(self): cm = torch._C._profiler._RecordFunctionFast( "add_test_kwinputs", [x, y], - {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'}, + { + "stream": 0, + "grid": "lambda x : x + 1", + "debug": 'debug"', + "boolean": True, + }, ) for _ in range(4): with cm: @@ -1658,12 +1685,38 @@ def test_profiler_op_event_kwargs(self): ] for e in op_events: if e["name"] == "add_test_kwinputs": + print(e["args"]) args = e["args"] self.assertTrue("stream" in args) self.assertTrue("grid" in args) - self.assertTrue(args["stream"] == "0") + self.assertTrue("boolean" in args) + self.assertTrue(args["stream"] == 0) self.assertTrue(args["grid"] == "lambda x : x + 1") self.assertTrue(args["debug"] == "None") + self.assertTrue(args["boolean"]) + + with profile(record_shapes=True) as p1: + cm = torch._C._profiler._RecordFunctionFast( + "add_test_kwinputs", + [x, y], + {"stream": "test", "grid": [1, 2]}, + ) + for _ in range(4): + with cm: + x.add(y) + with TemporaryFileName(mode="w+") as fname1: + p1.export_chrome_trace(fname1) + with open(fname1) as f1: + j = json.load(f1) + op_events = [ + e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op" + ] + for e in op_events: + if e["name"] == "add_test_kwinputs": + print(e["args"]) + args = e["args"] + self.assertTrue("stream" not in args) + self.assertTrue("grid" not in args) def test_is_profiler_enabled(self): self.assertFalse(torch.autograd.profiler._is_profiler_enabled) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index b432f770a0abb7..999b63dbbb0a65 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4223,8 +4223,8 @@ def test_wrapped_quantized_linear(self, m, n, k): ret_ref = qlinear.dequantize() self.assertEqual(ret, ret_ref) - """Tests the correctness of the _quantized::wrapped_linear_prepack and - _quantized::wrapped_quantized_linear_prepacked ops.""" + """Tests the correctness of the _quantized::_wrapped_linear_prepack and + _quantized::_wrapped_quantized_linear_prepacked ops.""" @skipIfNoFBGEMM @given( m=st.integers(2, 6), @@ -4243,13 +4243,13 @@ def test_wrapped_quantized_linear_prepacked(self, m, n, k): output_zero_point = torch.tensor(0) out_channel = n - ret_1 = torch.ops._quantized.wrapped_linear_prepack( + ret_1 = torch.ops._quantized._wrapped_linear_prepack( weight, weight_scale, weight_zero_point, bias ) - ret_2 = torch.ops._quantized.wrapped_quantized_linear_prepacked( + ret_2 = torch.ops._quantized._wrapped_quantized_linear_prepacked( input, input_scale, input_zero_point, diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 168c00920f623a..964dc051c3a434 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -67,14 +67,30 @@ NP_RANDOM_SEED = 19 tolerance = 1e-6 +# copy and modified from torch/ao/quantization/observer.py +_INT_DTYPES = ( + torch.qint8, + torch.quint8, + torch.quint4x2, + torch.qint32, + torch.int8, + torch.uint8, + torch.int16, + torch.int32, + torch.uint16, +) + class TestObserver(QuantizationTestCase): - @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)), + @given(qdtype=st.sampled_from(_INT_DTYPES), qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), reduce_range=st.booleans()) def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): # reduce_range cannot be true for symmetric quantization with uint8 if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32: reduce_range = False + if qdtype == torch.quint4x2: + return + ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range), MovingAverageMinMaxObserver(averaging_constant=0.5, dtype=qdtype, @@ -82,18 +98,23 @@ def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): reduce_range=reduce_range)] def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val): + assert dtype in _INT_DTYPES, "Not supported dtype: {dtype}, supported dtypes are {_INT_DTYPES}" eps = torch.tensor([tolerance]) - if dtype == torch.qint8: + if dtype in [torch.qint8, torch.int8]: if reduce_range: quant_min, quant_max = -64, 63 else: quant_min, quant_max = -128, 127 - elif dtype == torch.quint8: + elif dtype in [torch.quint8, torch.uint8]: if reduce_range: quant_min, quant_max = 0, 127 else: quant_min, quant_max = 0, 255 - elif dtype == torch.qint32: + elif dtype == torch.int16: + quant_min, quant_max = -1 * (2 ** 15), (2 ** 15) - 1 + elif dtype == torch.uint16: + quant_min, quant_max = 0, (2 ** 16) - 1 + elif dtype in [torch.qint32, torch.int32]: quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1 min_val_neg = torch.tensor([0.]) @@ -103,12 +124,15 @@ def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val) if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2) scale = torch.max(scale, eps) - if dtype == torch.quint8: + if dtype in [torch.quint8, torch.uint8]: zero_point = 128 + if dtype in [torch.uint16]: + zero_point = 2 ** 15 else: scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps) zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale, zero_point for myobs in ObserverList: diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 905098c3e6aca0..e2b7236d2ef093 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -4,7 +4,6 @@ from typing import Any, Dict import torch -from torch._export import capture_pre_autograd_graph from torch.ao.quantization.observer import ( HistogramObserver, MinMaxObserver, @@ -24,6 +23,7 @@ OP_TO_ANNOTATOR, QuantizationConfig, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import QuantizationTestCase from torch.testing._internal.common_utils import IS_WINDOWS @@ -100,10 +100,10 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 5f2d1e2d3cf574..21488e8cacdd42 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -13,7 +13,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase -from torch.testing._internal.common_utils import IS_WINDOWS +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef class TestHelperModules: @@ -99,10 +99,10 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch._export.capture_pre_autograd_graph( + m = torch.export.export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) # Calibrate @@ -139,6 +139,8 @@ def _test_metadata_porting( self.assertEqual(v, node_tags[k]) return m + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_simple_metadata_porting(self): """ Model under test diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 2a67f6e23097fd..7808eb892579e9 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -9,6 +9,7 @@ from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ( compare_results, + CUSTOM_KEY, extract_results_from_loggers, generate_numeric_debug_handle, NUMERIC_DEBUG_HANDLE_KEY, @@ -19,16 +20,22 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import TestHelperModules -from torch.testing._internal.common_utils import IS_WINDOWS, TestCase +from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]: debug_handle_map: Dict[torch.fx.Node, int] = {} for node in model.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY in node.meta: - debug_handle_map[str(node)] = node.meta[NUMERIC_DEBUG_HANDLE_KEY] + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ + NUMERIC_DEBUG_HANDLE_KEY + ] return debug_handle_map @@ -47,8 +54,8 @@ def test_simple(self): unique_ids = set() count = 0 for n in m.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY in n.meta: - unique_ids.add(n.meta[NUMERIC_DEBUG_HANDLE_KEY]) + if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]: + unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]) count += 1 self.assertEqual(len(unique_ids), count) @@ -71,7 +78,7 @@ def test_quantize_pt2e_preserve_handle(self): res_counter = Counter(debug_handle_map.values()) repeated_debug_handle_ids = [2, 3, 6] # 3 ids were repeated because we copy over the id from node to its output observer - # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim, torch.ops.aten.conv1d.default + # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default for dh_id in repeated_debug_handle_ids: self.assertEqual(res_counter[dh_id], 2) @@ -110,19 +117,37 @@ def test_deepcopy_preserve_handle(self): self.assertEqual(debug_handle_map, debug_handle_map_ref) - @unittest.skip("All nodes' meta are preserved but get_attr nodes' meta are wrong.") + @skipIfCrossRef # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack + # trace of the mode torch function impl doesn't match the traced graph stored lineno. def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() generate_numeric_debug_handle(m) debug_handle_map_ref = _extract_debug_handles(m) - m_export = capture_pre_autograd_graph(m, example_inputs) + m_export = export_for_training(m, example_inputs).module() debug_handle_map = _extract_debug_handles(m_export) self.assertEqual(debug_handle_map, debug_handle_map_ref) + def test_run_decompositions_preserve_handle(self): + m = TestHelperModules.Conv2dThenConv1d() + example_inputs = m.example_inputs() + m = torch.export.export(m, example_inputs) + generate_numeric_debug_handle(m) + + debug_handle_map_ref = _extract_debug_handles(m) + + m_copy = copy.copy(m) + m_copy = m_copy.run_decompositions() + debug_handle_map = _extract_debug_handles(m_copy) + + # checking the map still has the same ids, the node may change + self.assertEqual( + set(debug_handle_map.values()), set(debug_handle_map_ref.values()) + ) + def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() @@ -135,7 +160,7 @@ def test_prepare_for_propagation_comparison(self): from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)] - self.assertEqual(len(loggers), 8) + self.assertEqual(len(loggers), 7) self.assertTrue("conv2d" in [logger.node_name for logger in loggers]) self.assertEqual(res, ref) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 767eb00b358659..06a3c56f16dbec 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -4,6 +4,7 @@ import torch from torch import Tensor from torch._export import capture_pre_autograd_graph +from torch._utils_internal import capture_pre_autograd_graph_using_training_ir from torch.ao.quantization import observer, ObserverOrFakeQuantize, QConfigMapping from torch.ao.quantization.qconfig import ( default_per_channel_symmetric_qnnpack_qconfig, @@ -599,6 +600,14 @@ def validate(self, model: torch.fx.GraphModule) -> None: def test_fixed_qparams_qspec_observer_dedup(self): class BackendAQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + act_qspec = FixedQParamsQuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + scale=1.0 / 256.0, + zero_point=0, + ) for node in model.graph.nodes: if ( node.op == "call_function" @@ -606,14 +615,6 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ): input_act = node.args[0] assert isinstance(input_act, Node) - act_qspec = FixedQParamsQuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - scale=1.0 / 256.0, - zero_point=0, - ) node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act: act_qspec, @@ -629,13 +630,6 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: assert isinstance(input_act, Node) input_act1 = node.args[1] assert isinstance(input_act, Node) - act_qspec = QuantizationSpec( - observer_or_fake_quant_ctr=observer.default_observer, - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - ) node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act0: act_qspec, @@ -1309,7 +1303,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = capture_pre_autograd_graph(m, example_inputs) + m = capture_pre_autograd_graph(m, (example_inputs,)) with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1452,7 +1446,6 @@ def forward(self, x): for n in m.graph.nodes: if n.op == "get_attr" and "frozen_param" in n.target: - self.assertIn("stack_trace", n.meta) for key in n.meta: self.assertEqual(n.meta[key], weight_meta[key]) @@ -1863,6 +1856,14 @@ def test_move_exported_model_dropout_inplace(self): self._test_move_exported_model_dropout(inplace=True) def _get_bn_train_eval_ops(self): + if capture_pre_autograd_graph_using_training_ir(): + return ( + torch.ops.aten.batch_norm.default, + torch.ops.aten.batch_norm.default, + ) + # TODO: This branch is going through a deprecated branch and should be deleted soon, + # after capture_pre_autograd_graph fully migrate to training IR + # T199018392 if TEST_WITH_ROCM: return ( torch.ops.aten.miopen_batch_norm.default, diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index c0dc5fb32e82d1..2a9aea449bc71a 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -604,6 +604,9 @@ def test_prepare_qat_conv_bn_fusion_getitem_placeholder(self): is returned as part of the match anyway (as a placeholder). """ + if capture_pre_autograd_graph_using_training_ir(): + self.skipTest("Not applicable to training IR") + class M(torch.nn.Module): def __init__(self, conv_class, bn_class): super().__init__() diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index aa254762db114f..8b357bd9b311ae 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -7,7 +7,6 @@ import torch import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.nn as nn -from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ObserverBase from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, @@ -18,6 +17,7 @@ QUANT_ANNOTATION_KEY, X86InductorQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, @@ -550,10 +550,10 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -1222,7 +1222,7 @@ def _test_linear_unary_helper( Test pattern of linear with unary post ops (e.g. relu) with X86InductorQuantizer. """ use_bias_list = [True, False] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_list = [False] if post_op_algo_list is None: post_op_algo_list = [None] @@ -1362,7 +1362,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op is supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -1466,7 +1466,7 @@ def test_linear_binary2(self): Since linear_1 has 2 users, we should annotate linear_2 for binary fusion instead of linear_1 """ example_inputs = (torch.randn(2, 16),) - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] is_qat_list = [False, True] is_dynamic_list = [False, True] @@ -1535,9 +1535,9 @@ def _test_linear_binary_unary_helper(self, is_qat=False, is_dynamic=False): Currently, only add as binary post op and relu as unary post op are supported. """ linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both] - # TODO test for inplace add after refactoring of capture_pre_autograd_graph + # TODO test for inplace add after refactoring of export_for_training inplace_add_list = [False] - # TODO test for inplace relu after refactoring of capture_pre_autograd_graph + # TODO test for inplace relu after refactoring of export_for_training inplace_relu_list = [False] example_inputs = (torch.randn(2, 16),) quantizer = X86InductorQuantizer().set_global( @@ -2110,7 +2110,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 5e850a68482886..12962c8f3b0099 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -4,7 +4,7 @@ import torch import torch._dynamo as torchdynamo -from torch._export import capture_pre_autograd_graph +from torch._utils_internal import capture_pre_autograd_graph_using_training_ir from torch.ao.ns.fx.utils import compute_sqnr from torch.ao.quantization import ( default_dynamic_fake_quant, @@ -30,6 +30,7 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, PT2EQuantizationTestCase, @@ -361,7 +362,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) + m = export_for_training(m, example_inputs).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. @@ -497,10 +498,10 @@ def test_propagate_annotation(self): example_inputs = (torch.randn(1, 3, 5, 5),) # program capture - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -680,6 +681,20 @@ def test_dynamic_linear_with_conv(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, } + + capture_pre_autograd_graph_node_occurrence = None + if capture_pre_autograd_graph_using_training_ir(): + capture_pre_autograd_graph_node_occurrence = { + # input and output are using quantize_per_tensor and weight is using quantize_per_channel + # In training IR, the decomposition is different. + # `torch.ops.quantized_decomposed.quantize_per_tensor.default` nodes becomes + # `torch.ops.quantized_decomposed.quantize_per_tensor.tensor` nodes. + torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, + } act_affine_quant_obs = observer.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, @@ -703,6 +718,7 @@ def test_dynamic_linear_with_conv(self): [], True, qconfig_mapping, + capture_pre_autograd_graph_node_occurrence=capture_pre_autograd_graph_node_occurrence, ) def test_gru(self): @@ -754,10 +770,10 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = capture_pre_autograd_graph( + model_graph = export_for_training( model_graph, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -818,10 +834,10 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = capture_pre_autograd_graph( + model_graph = export_for_training( model_graph, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -995,10 +1011,10 @@ def test_resnet18(self): m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = capture_pre_autograd_graph( + m = export_for_training( m, example_inputs, - ) + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/test/run_test.py b/test/run_test.py index 80a724e129a7a2..231a1b2b7ca012 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -573,13 +573,8 @@ def run_test( def try_set_cpp_stack_traces(env, command, set=True): # Print full c++ stack traces during retries - # Don't do it for macos inductor tests as it makes them - # segfault for some reason - if not ( - IS_MACOS and len(command) >= 2 and command[2].startswith(INDUCTOR_TEST_PREFIX) - ): - env = env or {} - env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0" + env = env or {} + env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0" return env @@ -1422,7 +1417,7 @@ def get_selected_tests(options) -> List[str]: options.exclude.extend(CPP_TESTS) if options.mps: - selected_tests = ["test_mps", "test_metal", "test_modules"] + selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"] else: # Exclude all mps tests otherwise options.exclude.extend(["test_mps", "test_metal"]) diff --git a/test/test_autocast.py b/test/test_autocast.py index 24f87944990d8f..1b25750b404c00 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -1,10 +1,12 @@ # Owner(s): ["module: unknown"] -import collections import unittest import torch -from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists +from torch.testing._internal.autocast_test_lists import ( + AutocastCPUTestLists, + TestAutocast, +) from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, @@ -14,7 +16,7 @@ from torch.utils._python_dispatch import TorchDispatchMode -class TestAutocastCPU(TestCase): +class TestAutocastCPU(TestAutocast): def setUp(self): super().setUp() self.autocast_lists = AutocastCPUTestLists(torch.device("cpu")) @@ -23,100 +25,6 @@ def tearDown(self): del self.autocast_lists super().tearDown() - def _run_autocast_outofplace( - self, - op, - args, - run_as_type, - out_type=None, - module=torch, - add_kwargs=None, - amp_dtype=torch.bfloat16, - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - - self.assertFalse(torch.is_autocast_cpu_enabled()) - with torch.cpu.amp.autocast(dtype=amp_dtype): - self.assertTrue(torch.is_autocast_cpu_enabled()) - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) - - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) - - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second - - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) - - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.cpu.amp.autocast(enabled=False): - self.assertFalse(torch.is_autocast_cpu_enabled()) - - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_cpu_enabled()) - self.assertFalse(torch.is_autocast_cpu_enabled()) - - def args_maybe_kwargs(self, op_with_args): - if len(op_with_args) == 2: - return op_with_args[0], op_with_args[1], {} - else: - return op_with_args[0], op_with_args[1], op_with_args[2] - @skipIfTorchDynamo() def test_autocast_torch_expect_builtin_promote(self): for ( @@ -125,9 +33,16 @@ def test_autocast_torch_expect_builtin_promote(self): args2, out_type, ) in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type) self._run_autocast_outofplace( - op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16 + op, args1, torch.float32, device="cpu", out_type=out_type + ) + self._run_autocast_outofplace( + op, + args2, + torch.float32, + device="cpu", + out_type=out_type, + amp_dtype=torch.float16, ) @skipIfTorchDynamo() @@ -139,12 +54,13 @@ def test_autocast_methods_expect_builtin_promote(self): out_type, ) in self.autocast_lists.methods_expect_builtin_promote: self._run_autocast_outofplace( - op, args1, torch.float32, module=None, out_type=out_type + op, args1, torch.float32, device="cpu", module=None, out_type=out_type ) self._run_autocast_outofplace( op, args2, torch.float32, + device="cpu", module=None, out_type=out_type, amp_dtype=torch.float16, @@ -155,12 +71,13 @@ def test_autocast_torch_16(self): for op_with_args in self.autocast_lists.torch_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.bfloat16, add_kwargs=maybe_kwargs + op, args, torch.bfloat16, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float16, + device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @@ -170,12 +87,18 @@ def test_autocast_nn_16(self): for op_with_args in self.autocast_lists.nn_16: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs + op, + args, + torch.bfloat16, + device="cpu", + module=torch._C._nn, + add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float16, + device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, @@ -186,12 +109,13 @@ def test_autocast_torch_fp32(self): for op_with_args in self.autocast_lists.torch_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.float32, add_kwargs=maybe_kwargs + op, args, torch.float32, device="cpu", add_kwargs=maybe_kwargs ) self._run_autocast_outofplace( op, args, torch.float32, + device="cpu", add_kwargs=maybe_kwargs, amp_dtype=torch.float16, ) @@ -201,12 +125,18 @@ def test_autocast_nn_fp32(self): for op_with_args in self.autocast_lists.nn_fp32: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( - op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs + op, + args, + torch.float32, + device="cpu", + module=torch._C._nn, + add_kwargs=maybe_kwargs, ) self._run_autocast_outofplace( op, args, torch.float32, + device="cpu", module=torch._C._nn, add_kwargs=maybe_kwargs, amp_dtype=torch.float16, @@ -215,9 +145,9 @@ def test_autocast_nn_fp32(self): @skipIfTorchDynamo() def test_autocast_torch_need_autocast_promote(self): for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args1, torch.float32) + self._run_autocast_outofplace(op, args1, torch.float32, device="cpu") self._run_autocast_outofplace( - op, args2, torch.float32, amp_dtype=torch.float16 + op, args2, torch.float32, device="cpu", amp_dtype=torch.float16 ) @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") @@ -237,7 +167,7 @@ def test_autocast_rnn(self): m(x, (hx, cx)) # Should be able to run the below case with autocast - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): m(x, (hx, cx)) def test_autocast_disabled_with_fp32_dtype(self): @@ -249,7 +179,7 @@ def test_generic_autocast(self): op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) with torch.amp.autocast(device_type="cpu"): generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) - with torch.cpu.amp.autocast(): + with torch.amp.autocast(device_type="cpu"): cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs) self.assertEqual(generic_autocast_output, cpu_autocast_output) @@ -343,11 +273,86 @@ def test_cache_disabled(self): finally: torch._C._set_cached_tensors_enabled(False) + # index_put under AMP follows a cast policy called "promote", + # https://github.com/pytorch/pytorch/blob/4fcd15a667df5b80e81db6563d8d3123a0cbd051/aten/src/ATen/autocast_mode.h#L205-L230 + # That means: + # (1) double precision is ignored, + # (2) if any argument is float, then all arguments are promoted to float, + # (3) if all arguments are of lower precision dtype, then all dtypes must be equal to the same amp autocast dtype. + # Since AMP autocast dtype is thread-local, it is not preserved across thread boundaries during autograd execution, + # and due to the multi-threaded nature of the autograd, the forward pass is being run in bfloat16, while the backward + # pass defaults to float16. The dtype mismatch leads to the error in the policy, as the criteria (3) is not satisfied. + # For more info see https://github.com/pytorch/pytorch/issues/132715. + def test_autocast_prioritize(self): + device = "cuda" + dtype = torch.bfloat16 + + with torch.autocast(device_type=device, enabled=True, dtype=dtype): + t = torch.randn([3, 4, 5], dtype=dtype, device=device, requires_grad=True) + index = torch.randint( + low=0, high=3, size=[3, 4, 5], dtype=torch.int64, device=device + ) + val = torch.randn(1, dtype=dtype, device=device) + + res = torch.index_put(t, [index], val) + + loss = res.mean() + loss.backward() + + +@unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") +class TestAutocastMPS(TestCase): + def test_cast_cache_is_global(self): + class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, w_t): + ctx.save_for_backward(x, w_t) + return torch.nn.functional.linear(x, w_t) + + @staticmethod + def backward(ctx, grad_output): + x, w_t = ctx.saved_tensors + with torch.autocast(device_type="mps"): + dL_dX = torch.matmul(grad_output, w_t) + dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1) + return dL_dX, dL_dW + + data = torch.randn(2, 3).to("mps") + weight = torch.nn.Parameter(torch.randn(4, 3).to("mps")) + weight_dtype_cast_counter = 0 + + class WeightDTypeCastCounterMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if ( + func is torch.ops.aten._to_copy.default + and args[0] is weight + and kwargs["dtype"] is torch.float16 + ): + nonlocal weight_dtype_cast_counter + weight_dtype_cast_counter += 1 + return func(*args, **kwargs) + + def __enter__(self): + # self.old_clear_cache = torch.clear_autocast_cache + # torch.clear_autocast_cache = lambda: None + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + # torch.clear_autocast_cache = self.old_clear_cache + return super().__exit__(exc_type, exc_val, exc_tb) + + with WeightDTypeCastCounterMode(): + with torch.autocast(device_type="mps"): + output = CustomLinear.apply(data, weight) + s = output.sum() + s.backward() + self.assertEqual(weight_dtype_cast_counter, 2) + class TestTorchAutocast(TestCase): def test_autocast_fast_dtype(self): - gpu_fast_dtype = torch.get_autocast_gpu_dtype() - cpu_fast_dtype = torch.get_autocast_cpu_dtype() + gpu_fast_dtype = torch.get_autocast_dtype(device_type="cuda") + cpu_fast_dtype = torch.get_autocast_dtype(device_type="cpu") self.assertEqual(gpu_fast_dtype, torch.half) self.assertEqual(cpu_fast_dtype, torch.bfloat16) diff --git a/test/test_autograd.py b/test/test_autograd.py index a3b5badd69ab76..c7fabb725082c1 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -367,6 +367,7 @@ def backward(ctx, grad): x = torch.ones(1, requires_grad=True) torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward() + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_set_materialize_non_diff_grads(self): class Func(torch.autograd.Function): @staticmethod @@ -1935,6 +1936,50 @@ def prehook2(grad_output): self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) self.assertEqual(counter[0], 1) + def test_node_post_hook_registered_during_unpack_hook(self): + """ + Test that post hooks registered during one of the node's + unpack hooks are properly restricted and will run properly. + """ + test_case = self + + class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self) -> None: + def pack_tensor(tensor: torch.Tensor) -> torch.Tensor: + return tensor + + def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor: + node = torch._C._current_autograd_node() + + def hook(outputs, inputs): + # Assert that inputs passed in are None + test_case.assertTrue(all(i is None for i in inputs)) + halved_outputs = tuple( + o / 2.0 if o is not None else None for o in outputs + ) + return halved_outputs + + node.register_hook(hook) + return tensor + + super().__init__(pack_tensor, unpack_tensor) + + a = torch.rand(3, 3, requires_grad=True) + + def model(): + var, mean = torch.var_mean(a, dim=0) + loss = (var + mean).sum() + loss.backward() + + model() + ref_grad = a.grad.clone() + + with RegisterPostNodeHook(): + model() + + # Verify that the post hook got called and the grad propagation worked + self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad) + def test_hooks_cpp(self): # Tests hooks for autograd function implemented in C++ bn = torch.nn.BatchNorm1d(5, affine=False) @@ -3307,6 +3352,7 @@ def backward(ctx, grad_output): y = x.masked_fill(mask, 0) y.sum().backward() + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_mark_non_differentiable_mixed(self): class MyFunction(Function): @staticmethod @@ -4097,6 +4143,7 @@ def assert_strict_equal(var1, var2): assert_strict_equal(xc, x) assert_strict_equal(yc, y) + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_dep_nograd(self): class F1(Function): @staticmethod @@ -8785,6 +8832,7 @@ def vjp(ctx, grad_out): gradcheck(Func.apply, (a,), check_forward_ad=True) + @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") def test_custom_function_forward_mode_non_differentiable(self): # returns differentiable type, marked non-differentiable class Func(torch.autograd.Function): @@ -9459,7 +9507,7 @@ def test_disabling_saved_tensor_hooks_nested(self): self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled()) - def test_saved_tensor_hooks_custom_error_propagaation(self): + def test_saved_tensor_hooks_custom_error_propagation(self): class CustomError(Exception): pass @@ -9537,7 +9585,7 @@ def unpack_hook(x): self.assertEqual(pack_count, 1) self.assertEqual(unpack_count, 1) - def test_save_tensor_hook_version_counter_not_shared(self): + def test_saved_tensors_hook_version_counter_not_shared(self): class Test(torch.autograd.Function): @staticmethod def forward(ctx, x): @@ -9634,7 +9682,7 @@ def unpack(name): y.sum().backward() self.assertEqual(2 * a, a.grad) - def test_default_saved_variable_hooks_double_backward(self): + def test_default_saved_tensors_hooks_double_backward(self): with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): a = torch.randn(5, requires_grad=True) y = a**3 @@ -9672,7 +9720,7 @@ def test_default_saved_variable_hooks_double_backward(self): # note that in that sense, a is saved twice self.assertEqual(6 * 8 * a, a.grad) - def test_wrapped_number_saved_variable_hooks(self): + def test_wrapped_number_saved_tensors_hooks(self): def err_hook(x): raise RuntimeError("this hook should not be called") diff --git a/test/test_bundled_images.py b/test/test_bundled_images.py index c91814af31a09b..1919e1cd4fe34a 100644 --- a/test/test_bundled_images.py +++ b/test/test_bundled_images.py @@ -4,7 +4,7 @@ import io -import cv2 +import cv2 # @manual import torch import torch.utils.bundled_inputs diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 616a6e0f4b5518..4e86ed458b0788 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization_map_location(self): + def test_open_device_numpy_serialization(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,6 +553,7 @@ def test_open_device_numpy_serialization_map_location(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) + # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -569,6 +570,15 @@ def test_open_device_numpy_serialization_map_location(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) + # Test metadata_only + with TemporaryFileName() as f: + with self.assertRaisesRegex( + RuntimeError, + "Cannot serialize tensors on backends with no storage under skip_data context manager", + ): + with torch.serialization.skip_data(): + torch.save(sd, f) + if __name__ == "__main__": common.run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index e5e25a678e2352..a8e35c1c9a35a9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1,7 +1,7 @@ # Owner(s): ["module: cuda"] -import collections import contextlib +import ctypes import gc import json import os @@ -28,7 +28,7 @@ segment_plot, trace_plot, ) -from torch.testing._internal.autocast_test_lists import AutocastTestLists +from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_cuda import ( _create_scaling_case, _get_torch_cuda_version, @@ -40,7 +40,12 @@ onlyCUDA, onlyNativeDeviceTypes, ) -from torch.testing._internal.common_optimizers import optim_db, optims, TensorTracker +from torch.testing._internal.common_optimizers import ( + _get_optim_inputs_including_global_cliquey_kwargs, + optim_db, + optims, + TensorTracker, +) from torch.testing._internal.common_utils import ( EXPANDABLE_SEGMENTS, freeze_rng_state, @@ -55,7 +60,6 @@ IS_WINDOWS, load_tests, NO_MULTIPROCESSING_SPAWN, - NoTest, parametrize, run_tests, serialTest, @@ -79,10 +83,6 @@ # sharding on sandcastle. This line silences flake warnings load_tests = load_tests -if not TEST_CUDA: - print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 - try: import torchvision.models # noqa: F401 from torchvision.models import resnet18 # noqa: F401 @@ -107,6 +107,7 @@ _cycles_per_ms = None +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") @torch.testing._internal.common_utils.markDynamoStrictTest class TestCuda(TestCase): _do_cuda_memory_leak_check = True @@ -115,10 +116,8 @@ class TestCuda(TestCase): def setUp(self): super().setUp() - self.autocast_lists = AutocastTestLists(torch.device("cuda:0")) def tearDown(self): - del self.autocast_lists super().tearDown() @property @@ -151,6 +150,20 @@ def test_pinned_memory_with_cudaregister_multithread(self): for thread in threads: thread.join() + def test_pinned_memory_empty_cache(self): + for alloc_settings in (True, False): + torch.cuda.memory._set_allocator_settings( + f"pinned_use_cuda_host_register:{alloc_settings}" + ) + try: + t = torch.ones(1024 * 1024, pin_memory=True) + self.assertTrue(t.is_pinned()) + del t + torch._C._host_emptyCache() + except RuntimeError as e: + # Some GPUs don't support same address space on host and device side + pass + def test_cudart_register(self): t = torch.ones(20) self.assertFalse(t.is_pinned()) @@ -1538,3323 +1551,2643 @@ def _worker(t): for t in range(num_threads): self.assertEqual(results[t].sum().item(), size * size) - def _run_autocast_outofplace( - self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 - self.assertFalse(torch.is_autocast_enabled()) - with torch.autocast("cuda", dtype=fast_dtype): - self.assertTrue(torch.is_autocast_enabled()) - - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) + @slowTest + @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") + @serialTest() + def test_max_large_axis(self): + x = torch.zeros(2**32, device="cuda", dtype=torch.int8) + x[-1] = 1 + val, idx = x.max(0) + self.assertEqual(val, 1) + self.assertEqual(idx, x.shape[0] - 1) - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_to_numpy(self): + self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) + def test_graph_is_current_stream_capturing(self): + self.assertFalse(torch.cuda.is_current_stream_capturing()) - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second + if TEST_CUDA and (not TEST_WITH_ROCM): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + g = torch.cuda.CUDAGraph() + self.assertFalse(torch.cuda.is_current_stream_capturing()) + g.capture_begin() + self.assertTrue(torch.cuda.is_current_stream_capturing()) + g.capture_end() - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_capture_simple(self): + s = torch.cuda.Stream() - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.autocast("cuda", enabled=False): - self.assertFalse(torch.is_autocast_enabled()) + with torch.cuda.stream(s): + a = torch.full((1000,), 1, device="cuda") + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + g.capture_begin() + b = a + for _ in range(10): + b = b + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_enabled()) - self.assertFalse(torch.is_autocast_enabled()) - - def args_maybe_kwargs(self, op_with_args): - if len(op_with_args) == 2: - return op_with_args[0], op_with_args[1], {} - else: - return op_with_args[0], op_with_args[1], op_with_args[2] + g.replay() - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op_with_args in self.autocast_lists.torch_fp16: - skip_test = False - op, args = op_with_args[0], op_with_args[1] - if len(op_with_args) == 3: - skip_test = op_with_args[2] # TEST_WITH_ROCM - if not skip_test: - self._run_autocast_outofplace(op, args, torch.float16) + self.assertTrue(b.sum().item() == 11000.0) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_bf16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op_with_args in self.autocast_lists.torch_fp16: - skip_test = False - op, args = op_with_args[0], op_with_args[1] - if len(op_with_args) == 3: - skip_test = op_with_args[2] # TEST_WITH_ROCM - should_error_from_cudnn = "cudnn" in op and ( - "TORCH_CUDNN_V8_API_DISABLED" in os.environ - and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"]) - or torch.cuda.get_device_capability() < (8, 0) - ) - should_error_from_not_implemented = should_error_from_cudnn - if not skip_test: - if should_error_from_not_implemented: - with self.assertRaises( - RuntimeError, - msg=str(op) + " should not be supported for bfloat16!", - ): - self._run_autocast_outofplace(op, args, torch.bfloat16) - else: - if torch.cuda.is_bf16_supported(): - self._run_autocast_outofplace(op, args, torch.bfloat16) - else: - with self.assertRaisesRegex( - RuntimeError, "Device does not support bfloat16" - ): - self._run_autocast_outofplace(op, args, torch.bfloat16) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graphsafe_set_get_rng_state(self): + # Define a function to create generator states, with optional graph registration + def create_states(generator): + """Initializes generator states and registers them with a CUDA graph if provided.""" + # Ensure the CUDA generator is initialized + torch.rand(1, device="cuda") + generator.manual_seed(0) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_fp32(self): - for op_with_args in self.autocast_lists.torch_fp32: - op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) - self._run_autocast_outofplace( - op, args, torch.float32, add_kwargs=maybe_kwargs - ) + # Save the current state of the generator + old_state = generator.graphsafe_get_state() + # Create and save a cloned state of the generator + new_state = generator.clone_state() + # Return the original generator and its two states + return generator, old_state, new_state - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_need_autocast_promote(self): - for op, args in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args, torch.float32) + def register_states_to_graph(generator_state, graph): + generator, old_state, new_state = generator_state + graph.register_generator_state(old_state) + graph.register_generator_state(new_state) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_torch_expect_builtin_promote(self): - for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + # Define a function to perform specific RNG actions using the generator's states + def perform_random_generation_steps(generator_state): + generator, old_state, new_state = generator_state + random_values = [] - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.nn_fp16: - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._nn - ) + # Generate random numbers with the new generator state + generator.graphsafe_set_state(new_state) + random_values.append(torch.rand(5, device="cuda", generator=generator)) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_bf16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.nn_fp16: - if torch.cuda.is_bf16_supported(): - self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn - ) - else: - with self.assertRaisesRegex( - RuntimeError, "Device does not support bfloat16" - ): - self._run_autocast_outofplace( - op, args, torch.bfloat16, module=torch._C._nn - ) + # Generate random numbers twice with the old generator state + generator.graphsafe_set_state(old_state) + random_values.extend( + [torch.rand(5, device="cuda", generator=generator) for _ in range(2)] + ) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_nn_fp32(self): - for op, args in self.autocast_lists.nn_fp32: - self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn) + return random_values - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_linalg_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.linalg_fp16: - self._run_autocast_outofplace( - op, args, torch.float16, module=torch._C._linalg - ) + # Define a function to retrieve the final offsets of the original and new generator states + def get_final_offsets_of_states(generator_state): + generator, old_state, new_state = generator_state + old_state_offset = old_state.get_offset() + new_state_offset = new_state.get_offset() + return old_state_offset, new_state_offset - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_fp16(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.autocast_lists.methods_fp16: - self._run_autocast_outofplace(op, args, torch.float16, module=None) + # Set up and test a new CUDA generator + generator = torch.Generator(device="cuda") + generator_state = create_states(generator) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_fp32(self): - for op, args in self.autocast_lists.methods_fp32: - self._run_autocast_outofplace(op, args, torch.float32, module=None) + # Set up and test the default CUDA generator with a CUDA Graph + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + default_generator = torch.cuda.default_generators[0] + default_generator_state = create_states(default_generator) + register_states_to_graph(default_generator_state, g) - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_methods_expect_builtin_promote(self): - for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: - self._run_autocast_outofplace( - op, args, torch.float32, module=None, out_type=out_type + # Perform random number generation within a CUDA graph + with torch.cuda.stream(s): + g.capture_begin() + graphed_random_values = perform_random_generation_steps( + default_generator_state ) + g.capture_end() - def test_autocast_banned(self): - with torch.autocast("cuda"): - for op, args, module in self.autocast_lists.banned: - with self.assertRaises(RuntimeError): - getattr(module, op)(*args) - - def test_autocast_ignored_types(self): - with torch.autocast("cuda"): - for ignore_type in (torch.double, torch.int32): - a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") - b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") - c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0") - - # Tests if CastPolicy::fp16 ops ignore double and int - # Currently, no ops belonging to this policy support integer inputs. - if ignore_type is torch.double: - with self.assertRaises(RuntimeError): - torch.mm(a_ignore, c_16) - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.mm(a_ignore, b_ignore).dtype - self.assertTrue( - torch.mm(a_ignore, b_ignore).dtype is type_no_autocast - ) + # Synchronize the streams and replay the graph + torch.cuda.current_stream().wait_stream(s) + for _ in range(3): + random_values = perform_random_generation_steps(generator_state) + g.replay() + offset = get_final_offsets_of_states(generator_state) + graph_offset = get_final_offsets_of_states(default_generator_state) - # Tests if CastPolicy::fp32 ops ignore double and int - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.pow(a_ignore, 2.0).dtype - self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast) + # Compare the final offsets of states for both generators to ensure consistency + self.assertTrue(offset == graph_offset) + # Compare the states generated outside and inside the graph + self.assertEqual(random_values, graphed_random_values) - # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.sum(a_ignore).dtype - self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_memory_stats_of_multiple_generators_and_graphs(self): + # Function to clear CUDA cache and collect garbage + def clear_cuda_cache(): + gc.collect() + torch.cuda.empty_cache() - # Tests if CastPolicy::fp32_append_dtype ops ignore double and int - # Currently, no ops belonging to this policy support integer inputs. - if ignore_type is torch.double: - with torch.autocast("cuda", enabled=False): - type_no_autocast = torch.norm(a_ignore).dtype - self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast) + # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. + def simple_graph_task(graph): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + graph.capture_begin() + torch.rand(1, device="cuda") + graph.capture_end() + torch.cuda.current_stream().wait_stream(s) + graph.replay() # Replays the captured operations - def test_autocast_custom_enabled(self): - class MyMM(torch.autograd.Function): - @staticmethod - @torch.amp.custom_fwd(device_type="cuda") - def forward(ctx, a, b): - self.assertTrue(a.dtype is torch.float32) - self.assertTrue(b.dtype is torch.float32) - self.assertTrue(torch.is_autocast_enabled()) - ctx.save_for_backward(a, b) - return a.mm(b) + def get_memory_stats(): + stats = torch.cuda.memory_stats() + num_blocks = stats["active.all.current"] + total_size = stats["active_bytes.all.current"] + return num_blocks, total_size - @staticmethod - @torch.amp.custom_bwd(device_type="cuda") - def backward(ctx, grad): - self.assertTrue(torch.is_autocast_enabled()) - a, b = ctx.saved_tensors - a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) - self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) - return a_grad, b_grad + def test(num_graphs, num_generators): + baseline = get_memory_stats() + baseline_num_blocks, baseline_total_size = baseline - mymm = MyMM.apply + # Allocate CUDA graphs + graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)] - x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) - y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + # Allocate and manage generator states + default_generator = torch.cuda.default_generators[0] + generators = [default_generator.graphsafe_get_state()] - dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) - for dtype in dtypes: - with torch.cuda.amp.autocast(dtype=dtype): - output = mymm(x, y) - self.assertTrue(output.dtype is dtype) - loss = output.sum() - loss.backward() + # Starts from 1 as one state is already added + for _ in range(1, num_generators): + generators.append(default_generator.clone_state()) - def test_autocast_custom_cast_inputs(self): - class MyMM(torch.autograd.Function): - @staticmethod - @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(ctx, a, container, expect_type): - b = container[1][0] - self.assertTrue(a.dtype is expect_type) - self.assertTrue(b.dtype is expect_type) - self.assertFalse(torch.is_autocast_enabled()) - ctx.save_for_backward(a, b) - return a.mm(b) + for graph in graphs: + for generator_state in generators: + graph.register_generator_state(generator_state) + simple_graph_task(graph) - @staticmethod - @torch.amp.custom_bwd(device_type="cuda") - def backward(ctx, grad): - self.assertFalse(torch.is_autocast_enabled()) - a, b = ctx.saved_tensors - return grad.mm(b.t()), None, None + # Assert conditions after graph tasks + num_blocks, total_size = get_memory_stats() + # The allocated blocks should only be proportional to the number of generators + expected_blocks_diff = 2 * num_generators + expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 - mymm = MyMM.apply + self.assertTrue( + (num_blocks - baseline_num_blocks) == expected_blocks_diff, + "Unexpected number of active blocks.", + ) + self.assertTrue( + (total_size - baseline_total_size) == expected_size_diff, + "Unexpected total memory size.", + ) - x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) - # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient, - # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments. - # Sets requires_grad=False explicitly so we don't lie about expecting a gradient. - y = ( - 0, - { - 0: torch.randn( - (8, 8), device="cuda", dtype=torch.float16, requires_grad=False - ) - }, - ) + # Cleanup graphs and clear CUDA cache + while graphs: + graph = graphs.pop() + del graph + clear_cuda_cache() - with torch.autocast("cuda"): - output = mymm(x, y, torch.float32) - self.assertTrue(output.dtype is torch.float32) - loss = output.sum() - loss.backward() + # Assert that memory stats return to baseline after cleanup + self.assertTrue( + get_memory_stats() == baseline, + "Memory stats do not match baseline after cleanup.", + ) - # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region. - output = mymm(x, y, torch.float16) - self.assertTrue(output.dtype is torch.float16) - loss = output.sum() - loss.backward() + # Running the test function with different parameters + test(1, 1) + test(3, 2) + test(10, 20) - def test_autocast_custom_deprecated_warning(self): - with warnings.catch_warnings(record=True) as w: + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_capture_reset_recapture(self): + s = torch.cuda.Stream() - class MyMM(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) - def forward(ctx, x, y): - ctx.save_for_backward(x, y) - self.assertFalse(torch.is_autocast_enabled()) - return x + y + with torch.cuda.stream(s): + a = torch.full((1000,), 1, device="cuda") + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + g.capture_begin() + b = a + for _ in range(10): + b = b + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - _, _ = ctx.saved_tensors - self.assertFalse(torch.is_autocast_enabled()) - return grad, grad + g.replay() - self.assertRegex( - str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." - ) - self.assertRegex( - str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." - ) + self.assertTrue(b.sum().item() == 11000.0) - mymm = MyMM.apply - x = torch.randn(3, 3, requires_grad=True) - y = torch.randn(3, 3, requires_grad=True) - with torch.amp.autocast("cuda"): - output = mymm(x, y) - loss = output.sum() - loss.backward() + g.reset() - def test_autocast_cat_jit(self): - # Reported at https://github.com/pytorch/pytorch/issues/38958 + with torch.cuda.stream(s): + g.capture_begin() + b.fill_(2.0) + for _ in range(10): + b = b + 2 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - class Model(torch.nn.Module): - def forward(self): - a = torch.randn(1) - b = torch.randn(1) - c = torch.cat((a, b), 0) - d = torch.stack([c, c], 0) - return d + g.replay() + self.assertTrue(b.sum().item() == 22000.0) - # The JIT here doesn't really matter, we just need to call - # cat via the boxed API - model = Model() - model_jit_script = torch.jit.script(model) + g.reset() + del g - with torch.autocast("cuda", enabled=True): - model() - model_jit_script() + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_debugdump(self): + torch.cuda.empty_cache() + x = torch.randn(10240000, device="cuda") + y = torch.rand_like(x) + g = torch.cuda.CUDAGraph() + g.enable_debug_mode() + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s0.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s0): + g.capture_begin() + z = x + y + with torch.cuda.stream(s1): + s1.wait_stream(s0) + w = z + y + s0.wait_stream(s1) + g.capture_end() + s0.synchronize() + torch.cuda.synchronize() + with tempfile.TemporaryDirectory() as tempdir: + g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) - # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) - # so they get a dedicated test. - # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). - @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") - def test_autocast_rnn(self): - with torch.backends.cudnn.flags(enabled=True, deterministic=True): - # seq, batch, features, hidden size - clses = ("RNN", "GRU", "LSTM") - T, B, F, H = 3, 4, 5, 6 - dtypes = (torch.float16, torch.float32) - input_layouts = ("seq_first", "batch_first", "packed") + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_error(self): + # We need to run this test in a separate thread as the error we trigger + # puts the cuda context in a bad state + script = """ +import torch - for ( - cls, - num_layers, - bias, - input_layout, - bidirectional, - try_nonpreflattened_weights, - input_dtype, - hidden_dtype, - weight_dtype, - ) in product( - clses, - (1, 2), - (True, False), - input_layouts, - (True, False), - (True, False), - dtypes, - dtypes, - dtypes, - ): - if input_layout == "seq_first": - batch_first = False - x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) - elif input_layout == "batch_first": - batch_first = True - x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) - elif input_layout == "packed": - batch_first = False - x = torch.nn.utils.rnn.pack_padded_sequence( - torch.randn((T, B, F), device="cuda", dtype=input_dtype), - lengths=(3, 2, 1, 3), - enforce_sorted=False, - ) - - rnn = ( - getattr(torch.nn, cls)( - F, - H, - num_layers=num_layers, - bidirectional=bidirectional, - bias=bias, - batch_first=batch_first, - ) - .cuda() - .to(dtype=weight_dtype) +g = torch.cuda.CUDAGraph() +try: + g.capture_begin() +except RuntimeError as e: + if "CUDA graphs must be captured on a non-default stream." in str(e): + exit(0) + else: + exit(1) +exit(2) +""" + try: + a = subprocess.check_output( + [sys.executable, "-c", script], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)), + ) + except subprocess.CalledProcessError as e: + if e.returncode == 1: + self.assertTrue( + False, + "Error raise by starting capture without a stream is not the expected one", ) - - if try_nonpreflattened_weights: - for p in rnn.parameters(): - with torch.no_grad(): - p.set_(p.clone()) - - h = torch.randn( - (num_layers * (2 if bidirectional else 1), B, H), - device="cuda", - dtype=hidden_dtype, + elif e.returncode == 2: + self.assertTrue( + False, + "Error raised by starting capture without a stream was not caught", ) - if cls == "LSTM": - c = torch.randn( - (num_layers * (2 if bidirectional else 1), B, H), - device="cuda", - dtype=hidden_dtype, - ) - h = (h, c) - with torch.autocast("cuda"): - out, h_out = rnn(x, h) - out = out.data if input_layout == "packed" else out - self.assertEqual(out.dtype, torch.float16) - # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee - # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has - # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. - self.assertEqual( - out.grad_fn.name(), - "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0", - ) - out.sum().backward() - grads = [p.grad.clone() for p in rnn.parameters()] + @unittest.skipIf( + (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, + "CUDA >= 11.0 required for graphs", + ) + def test_graph_warn_if_has_zero_nodes(self): + with warnings.catch_warnings(record=True) as caught: + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + with torch.cuda.stream(s): + g.capture_begin() + g.capture_end() + self.assertTrue( + any("The CUDA Graph is empty" in str(w.message) for w in caught) + ) - rnn.zero_grad() + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @unittest.skipIf( + IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" + ) + def test_graph_capture_oom(self): + oom_regex = ( + "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" + ) + with self.assertRaisesRegex(RuntimeError, oom_regex): + with torch.cuda.graph(torch.cuda.CUDAGraph()): + torch.zeros(2**40, device="cuda") - if cls == "LSTM": - out_control, h_out_control = rnn.to(dtype=torch.float16)( - x.half(), (h[0].half(), h[1].half()) - ) - else: - out_control, h_out_control = rnn.to(dtype=torch.float16)( - x.half(), h.half() - ) - out_control = ( - out_control.data if input_layout == "packed" else out_control - ) - out_control.sum().backward() - grads_control = [p.grad.clone() for p in rnn.parameters()] + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @serialTest() + def test_repeat_graph_capture_cublas_workspace_memory(self): + (x, y, z) = 1024, 512, 64 + a = torch.rand((x, y), device="cuda") + b = torch.rand((y, z), device="cuda") - # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, - # autocast and control results should be bitwise identical. - self.assertEqual(out, out_control) + # warmup + torch.mm(a, b) - if cls == "LSTM": - self.assertTrue( - h_out[0].dtype is torch.float16 - and h_out[1].dtype is torch.float16 - ) - self.assertEqual(h_out[0], h_out_control[0]) - self.assertEqual(h_out[1], h_out_control[1]) - else: - self.assertEqual(h_out.dtype, torch.float16) - self.assertEqual(h_out, h_out_control) - for grad, grad_control in zip(grads, grads_control): - self.assertEqual(grad.half(), grad_control) + free_bytes_before, total_bytes = torch.cuda.mem_get_info() + used_gb_before = (total_bytes - free_bytes_before) / 1e9 - def test_autocast_cache_leak(self): - # Reported at https://github.com/pytorch/pytorch/issues/48049 - # Test is used to check, if autocast recaches the same parameters - # when executed in a `torch.no_grad()` block. + for i in range(100): + torch_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(torch_graph): + torch.mm(a, b) + torch_graph.replay() - linear = torch.nn.Linear(10, 10).to("cuda") - data = torch.randn(1, 10, device="cuda") + free_bytes_after, _ = torch.cuda.mem_get_info() + used_gb_after = (total_bytes - free_bytes_after) / 1e9 - with torch.autocast("cuda"): - with torch.no_grad(): - out = linear(data) - first_iter_mem = torch.cuda.memory_allocated() - for _ in range(3): - out = linear(data) - self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + self.assertFalse(used_gb_before + 0.1 < used_gb_after) - def test_autocast_checkpointing(self): - model = torch.nn.Sequential( - torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8) - ).cuda() - input = torch.rand( - (8, 8), device="cuda", dtype=torch.float16, requires_grad=True + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_rng_functional(self): + ops_with_kwargs = ( + (torch.nn.functional.dropout, {"p": 0.1}), + (torch.nn.functional.rrelu, {"training": True}), ) - for reentrant in (True, False): - with torch.autocast("cuda"): - output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant) - self.assertTrue(output.requires_grad) - self.assertTrue(output.dtype is torch.float16) - output.sum().backward() - - def test_cuda_autocast_deprecated_warning(self): - with self.assertWarnsRegex( - FutureWarning, - r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", - ): - with torch.cuda.amp.autocast(): - _ = torch.ones(10) + size = 10000 - @slowTest - @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") - @serialTest() - def test_max_large_axis(self): - x = torch.zeros(2**32, device="cuda", dtype=torch.int8) - x[-1] = 1 - val, idx = x.max(0) - self.assertEqual(val, 1) - self.assertEqual(idx, x.shape[0] - 1) + def run(op, kwargs): + a = torch.randn((size,), device="cuda", dtype=torch.float) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy(self): - self.assertRaises(TypeError, lambda: torch.empty(1, device="cuda").numpy()) + # Control + torch.cuda.manual_seed(5) + eager_out = a + for _ in range(6): + eager_out = op(eager_out, **kwargs) - def test_graph_is_current_stream_capturing(self): - self.assertFalse(torch.cuda.is_current_stream_capturing()) + graph_in = a.clone() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) - if TEST_CUDA and (not TEST_WITH_ROCM): - s = torch.cuda.Stream() - with torch.cuda.stream(s): g = torch.cuda.CUDAGraph() - self.assertFalse(torch.cuda.is_current_stream_capturing()) + torch.cuda.empty_cache() g.capture_begin() - self.assertTrue(torch.cuda.is_current_stream_capturing()) + graph_out = graph_in + for _ in range(2): + graph_out = op(graph_out, **kwargs) g.capture_end() + torch.cuda.current_stream().wait_stream(stream) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_capture_simple(self): - s = torch.cuda.Stream() - - with torch.cuda.stream(s): - a = torch.full((1000,), 1, device="cuda") - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - b = a - for _ in range(10): - b = b + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + # Runs a graphed->eager->graphed sequence of RNG ops. + # replay() plays 2 invocations of the op, so the sequence has 6 + # invocations total, matching Control. + # replay() reads from graph_in and writes to graph_out. + g.replay() + out = op(graph_out, **kwargs) + out = op(out, **kwargs) + graph_in.copy_(out) + g.replay() - g.replay() + # If replay() updated RNG state correctly, graph_out + # should now hold data equal to eager_out. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e - self.assertTrue(b.sum().item() == 11000.0) + # Do the same operations varying seeds + seeds = [6, 128, 9999] + + for seed in seeds: + torch.cuda.manual_seed(seed) + graph_in.copy_(a) + for _ in range(3): + g.replay() + + # If the random seed was not updated then the graph would + # generate the same output as in previous check. + try: + self.assertNotEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # Now repeat the same operations in non-graphed mode. + torch.cuda.manual_seed(seed) + for _ in range(3): + eager_out.copy_(a) + eager_out = op(eager_out, **kwargs) + eager_out = op(eager_out, **kwargs) + + # In the end, graph_out and eager_out must be equal + # as they went under the same set of operations. + try: + self.assertEqual(eager_out, graph_out) + except Exception as e: + raise RuntimeError("Failed on ", op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op, kwargs in ops_with_kwargs: + run(op, kwargs) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graphsafe_set_get_rng_state(self): - # Define a function to create generator states, with optional graph registration - def create_states(generator): - """Initializes generator states and registers them with a CUDA graph if provided.""" - # Ensure the CUDA generator is initialized - torch.rand(1, device="cuda") - generator.manual_seed(0) + def test_graph_rng_distributions(self): + size = 10000 + input = torch.rand((size,), device="cuda", dtype=torch.float) + alloc = torch.empty((size,), device="cuda", dtype=torch.float) - # Save the current state of the generator - old_state = generator.graphsafe_get_state() - # Create and save a cloned state of the generator - new_state = generator.clone_state() - # Return the original generator and its two states - return generator, old_state, new_state + # Torch ops to test with sample args (tuple) and kwargs (dict) + torch_with_args = ( + ("bernoulli", (input.clone(),), {}), + # multinomial uses some uncapturable CUDA calls. + # TODO: reenable multinomial tests if/when the implementation is capturable. + # ("multinomial", (input.clone(), size, True), {}), + # ("multinomial", (input.clone(), size // 2, False), {}), + # TODO: reenable normal test, where std is a device + # tensor, when graph test failures are fixed + # ("normal", (input.clone() + 1, input.clone()), {}), + ("normal", (input.clone() + 1, 1.0), {}), + ("poisson", (input.clone(),), {}), + ("rand", (size,), {"device": "cuda", "dtype": torch.float}), + ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), + ("randn", (size,), {"device": "cuda", "dtype": torch.float}), + ) - def register_states_to_graph(generator_state, graph): - generator, old_state, new_state = generator_state - graph.register_generator_state(old_state) - graph.register_generator_state(new_state) + # Tensor methods to test with sample args (tuple) + tensor_with_args = ( + ("bernoulli_", (input.clone(),)), + ("cauchy_", ()), + ("exponential_", ()), + ("geometric_", (0.3,)), + ("log_normal_", ()), + ("normal_", ()), + ("random_", ()), + ("uniform_", ()), + ) - # Define a function to perform specific RNG actions using the generator's states - def perform_random_generation_steps(generator_state): - generator, old_state, new_state = generator_state - random_values = [] + def run(module, op, args, kwargs): + torch.cuda.manual_seed(5) - # Generate random numbers with the new generator state - generator.graphsafe_set_state(new_state) - random_values.append(torch.rand(5, device="cuda", generator=generator)) + # Each path runs a dummy op to increment the state a bit before creating controls. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + dummy = alloc.clone() + control1 = alloc.clone() + control2 = alloc.clone() + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) - # Generate random numbers twice with the old generator state - generator.graphsafe_set_state(old_state) - random_values.extend( - [torch.rand(5, device="cuda", generator=generator) for _ in range(2)] - ) + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + torch.cuda.manual_seed(5) - return random_values + g = torch.cuda.CUDAGraph() + torch.cuda.empty_cache() + if module == "torch": + g.capture_begin() + t1 = getattr(torch, op)(*args, **kwargs) + t2 = getattr(torch, op)(*args, **kwargs) + g.capture_end() + else: + t1 = alloc.clone() + t2 = alloc.clone() + g.capture_begin() + getattr(t1, op)(*args) + getattr(t2, op)(*args) + g.capture_end() + torch.cuda.current_stream().wait_stream(stream) - # Define a function to retrieve the final offsets of the original and new generator states - def get_final_offsets_of_states(generator_state): - generator, old_state, new_state = generator_state - old_state_offset = old_state.get_offset() - new_state_offset = new_state.get_offset() - return old_state_offset, new_state_offset + if not TEST_CUDAMALLOCASYNC: + # Makes sure values haven't been populated yet + # (in other words, makes sure capture didn't actually run ops). + # We can only try this with the native allocator, for which captured + # addresses are already backed by cudaMalloced memory. + # If we try it with cudaMallocAsync, CUDA won't event consider + # the captured addresses allocated until replay(), and if we + # access them before replay() we get IMAs. + try: + self.assertNotEqual(control1, t1) + self.assertNotEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e - # Set up and test a new CUDA generator - generator = torch.Generator(device="cuda") - generator_state = create_states(generator) + # Set a new seed to check if graph would use it + for seed in [6, 314, 271]: + torch.cuda.manual_seed(seed) + # Runs a dummy op prelude, as for controls, to make sure replay() + # picks up the dummy op's state increment. + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + control1 = getattr(torch, op)(*args, **kwargs) + control2 = getattr(torch, op)(*args, **kwargs) + else: + getattr(dummy, op)(*args) + getattr(control1, op)(*args) + getattr(control2, op)(*args) - # Set up and test the default CUDA generator with a CUDA Graph - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() - default_generator = torch.cuda.default_generators[0] - default_generator_state = create_states(default_generator) - register_states_to_graph(default_generator_state, g) + torch.cuda.manual_seed(seed) + if module == "torch": + dummy = getattr(torch, op)(*args, **kwargs) + else: + getattr(dummy, op)(*args) - # Perform random number generation within a CUDA graph - with torch.cuda.stream(s): - g.capture_begin() - graphed_random_values = perform_random_generation_steps( - default_generator_state - ) - g.capture_end() + # see above comment on TEST_CUDAMALLOCASYNC + if not TEST_CUDAMALLOCASYNC: + t1.copy_(alloc) + t2.copy_(alloc) - # Synchronize the streams and replay the graph - torch.cuda.current_stream().wait_stream(s) - for _ in range(3): - random_values = perform_random_generation_steps(generator_state) - g.replay() - offset = get_final_offsets_of_states(generator_state) - graph_offset = get_final_offsets_of_states(default_generator_state) + # Runs RNG ops that fill t1 and t2. + g.replay() - # Compare the final offsets of states for both generators to ensure consistency - self.assertTrue(offset == graph_offset) - # Compare the states generated outside and inside the graph - self.assertEqual(random_values, graphed_random_values) + try: + self.assertEqual(control1, t1) + self.assertEqual(control2, t2) + except Exception as e: + raise RuntimeError("Failed on " + module + "." + op) from e + + # We hold references to all tensors used across streams up til this sync, + # so no need to call record_stream on those tensors. + torch.cuda.synchronize() + + for op_with_args in torch_with_args: + run("torch", *op_with_args) + + for meth_with_args in tensor_with_args: + # Adds an empty dict for kwargs, which none of the Tensor methods use + run("Tensor", *(meth_with_args + ({},))) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_memory_stats_of_multiple_generators_and_graphs(self): - # Function to clear CUDA cache and collect garbage - def clear_cuda_cache(): - gc.collect() - torch.cuda.empty_cache() + def test_graph_two_successive(self): + torch.cuda.empty_cache() - # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. - def simple_graph_task(graph): - s = torch.cuda.Stream() - with torch.cuda.stream(s): - graph.capture_begin() - torch.rand(1, device="cuda") - graph.capture_end() - torch.cuda.current_stream().wait_stream(s) - graph.replay() # Replays the captured operations + size = 1000 + kSmallBuffer = 2097152 - def get_memory_stats(): - stats = torch.cuda.memory_stats() - num_blocks = stats["active.all.current"] - total_size = stats["active_bytes.all.current"] - return num_blocks, total_size - - def test(num_graphs, num_generators): - baseline = get_memory_stats() - baseline_num_blocks, baseline_total_size = baseline + def func_with_temps(t, val): + x = t.clone() + val + y = t.clone() + val + return x + y - # Allocate CUDA graphs - graphs = [torch.cuda.CUDAGraph() for _ in range(num_graphs)] + s = torch.cuda.Stream() - # Allocate and manage generator states - default_generator = torch.cuda.default_generators[0] - generators = [default_generator.graphsafe_get_state()] + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() - # Starts from 1 as one state is already added - for _ in range(1, num_generators): - generators.append(default_generator.clone_state()) + a = torch.ones((size,), device="cuda") - for graph in graphs: - for generator_state in generators: - graph.register_generator_state(generator_state) - simple_graph_task(graph) + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + for _ in range(5): + b = func_with_temps(b, 1) + g0.capture_end() - # Assert conditions after graph tasks - num_blocks, total_size = get_memory_stats() - # The allocated blocks should only be proportional to the number of generators - expected_blocks_diff = 2 * num_generators - expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 + g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args + g1.capture_begin(*g1_args) + for _ in range(5): + b = func_with_temps(b, 1) + g1.capture_end() + torch.cuda.current_stream().wait_stream(s) - self.assertTrue( - (num_blocks - baseline_num_blocks) == expected_blocks_diff, - "Unexpected number of active blocks.", - ) - self.assertTrue( - (total_size - baseline_total_size) == expected_size_diff, - "Unexpected total memory size.", - ) + # mixes unrelated eager ops with replays + c = a.clone() + for _ in range(2): + c = func_with_temps(c, 3) + g0.replay() + for _ in range(2): + c = func_with_temps(c, 3) + g1.replay() + for _ in range(2): + c = func_with_temps(c, 3) - # Cleanup graphs and clear CUDA cache - while graphs: - graph = graphs.pop() - del graph - clear_cuda_cache() + self.assertEqual(b.sum().item(), size * 3070) + self.assertEqual(c.sum().item(), size * 442) - # Assert that memory stats return to baseline after cleanup - self.assertTrue( - get_memory_stats() == baseline, - "Memory stats do not match baseline after cleanup.", - ) + if not TEST_CUDAMALLOCASYNC: + # These stat checks are specific to the native allocator. + if share_mem != "Don't share": + self.assertEqual( + reserved_no_sharing # noqa: F821 + - torch.cuda.memory_stats()["reserved_bytes.all.current"], + kSmallBuffer, + ) + else: + reserved_no_sharing = torch.cuda.memory_stats()[ + "reserved_bytes.all.current" + ] - # Running the test function with different parameters - test(1, 1) - test(3, 2) - test(10, 20) + del a, b, c, g0, g1 + # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() + @unittest.skipIf( + (not TEST_CUDA_GRAPH) + or IS_WINDOWS + or ( # appears to still be broken on Windows as of 11.4+ + torch.version.cuda + and int(torch.version.cuda.split(".")[0]) == 11 + and int(torch.version.cuda.split(".")[1]) < 4 + ), + "Graph bindings disallow concurrent replay for CUDA < 11.4, see " + + "https://github.com/pytorch/pytorch/pull/57556", + ) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_capture_reset_recapture(self): + def test_graph_concurrent_replay(self): + torch.cuda.empty_cache() + + size = 1000000 # largeish to help expose race conditions + + def func_with_temps(t, val): + x = t.clone() + val + y = t.clone() + val + return x + y + s = torch.cuda.Stream() - with torch.cuda.stream(s): - a = torch.full((1000,), 1, device="cuda") - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - b = a - for _ in range(10): - b = b + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() - g.replay() + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() - self.assertTrue(b.sum().item() == 11000.0) + a = torch.ones((size,), device="cuda") - g.reset() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + for _ in range(5): + b = func_with_temps(b, 1) + g0.capture_end() - with torch.cuda.stream(s): - g.capture_begin() - b.fill_(2.0) - for _ in range(10): - b = b + 2 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args + g1.capture_begin(*g1_args) + c = a.clone() + for _ in range(5): + c = func_with_temps(c, 2) + g1.capture_end() - g.replay() - self.assertTrue(b.sum().item() == 22000.0) + # To reproduce data corruption, I need g0 and g1's kernels to run concurrently. + # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead. + # The following pattern helps align device-side execution of g0 and g1's kernels. + torch.cuda.synchronize() + with torch.cuda.stream(s0): + torch.cuda._sleep(1000000) + s1.wait_stream(s0) + g0.replay() + with torch.cuda.stream(s1): + g1.replay() + torch.cuda.current_stream().wait_stream(s0) + torch.cuda.current_stream().wait_stream(s1) - g.reset() - del g + if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"): + # If we used the native allocator and shared mempools, + # we expect the concurrent replays corrupted each other. + self.assertNotEqual(b.sum().item(), size * 94) + self.assertNotEqual(c.sum().item(), size * 156) + else: + # If we EITHER + # - used the native allocator without sharing mempools, OR + # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe + # we don't expect memory corruption. + self.assertEqual(b.sum().item(), size * 94) + self.assertEqual(c.sum().item(), size * 156) + + del a, b, c, g0, g1 + # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_debugdump(self): + def test_graph_three_successive(self): torch.cuda.empty_cache() - x = torch.randn(10240000, device="cuda") - y = torch.rand_like(x) - g = torch.cuda.CUDAGraph() - g.enable_debug_mode() - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s0.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s0): - g.capture_begin() - z = x + y - with torch.cuda.stream(s1): - s1.wait_stream(s0) - w = z + y - s0.wait_stream(s1) - g.capture_end() - s0.synchronize() - torch.cuda.synchronize() - with tempfile.TemporaryDirectory() as tempdir: - g.debug_dump(os.path.join(tempdir, "out_multi_stream.dot")) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_error(self): - # We need to run this test in a separate thread as the error we trigger - # puts the cuda context in a bad state - script = """ -import torch + size = 1000 -g = torch.cuda.CUDAGraph() -try: - g.capture_begin() -except RuntimeError as e: - if "CUDA graphs must be captured on a non-default stream." in str(e): - exit(0) - else: - exit(1) -exit(2) -""" - try: - a = subprocess.check_output( - [sys.executable, "-c", script], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - cwd=os.path.dirname(os.path.realpath(__file__)), - ) - except subprocess.CalledProcessError as e: - if e.returncode == 1: - self.assertTrue( - False, - "Error raise by starting capture without a stream is not the expected one", - ) - elif e.returncode == 2: - self.assertTrue( - False, - "Error raised by starting capture without a stream was not caught", - ) + s = torch.cuda.Stream() - @unittest.skipIf( - (not TEST_CUDA) or TEST_WITH_ROCM or int(torch.version.cuda.split(".")[0]) < 11, - "CUDA >= 11.0 required for graphs", - ) - def test_graph_warn_if_has_zero_nodes(self): - with warnings.catch_warnings(record=True) as caught: - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() + for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): + a = torch.ones((size,), device="cuda") + + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() + g2 = torch.cuda.CUDAGraph() + + s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - g.capture_begin() - g.capture_end() - self.assertTrue( - any("The CUDA Graph is empty" in str(w.message) for w in caught) - ) + g0_args = ( + (torch.cuda.graph_pool_handle(),) + if share_mem == "via graph_pool_handle()" + else () + ) + g0.capture_begin(*g0_args) + b = a.clone() + c = b + 1 + d = b + 2 + g0.capture_end() - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @unittest.skipIf( - IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" - ) - def test_graph_capture_oom(self): - oom_regex = ( - "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else "out of memory" - ) - with self.assertRaisesRegex(RuntimeError, oom_regex): - with torch.cuda.graph(torch.cuda.CUDAGraph()): - torch.zeros(2**40, device="cuda") + args = (g0.pool(),) if share_mem == "via pool()" else g0_args - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @serialTest() - def test_repeat_graph_capture_cublas_workspace_memory(self): - (x, y, z) = 1024, 512, 64 - a = torch.rand((x, y), device="cuda") - b = torch.rand((y, z), device="cuda") + g1.capture_begin(*args) + e = c + 3 + del c + g1.capture_end() - # warmup - torch.mm(a, b) + g2.capture_begin(*args) + f = d + 4 + g2.capture_end() + torch.cuda.current_stream().wait_stream(s) - free_bytes_before, total_bytes = torch.cuda.mem_get_info() - used_gb_before = (total_bytes - free_bytes_before) / 1e9 + # Tests that replaying in capture order is valid + g0.replay() + g1.replay() + g2.replay() - for i in range(100): - torch_graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(torch_graph): - torch.mm(a, b) - torch_graph.replay() + self.assertEqual(e.sum().item(), size * 5) + self.assertEqual(f.sum().item(), size * 7) - free_bytes_after, _ = torch.cuda.mem_get_info() - used_gb_after = (total_bytes - free_bytes_after) / 1e9 + # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool + g0.replay() + g2.replay() + g1.replay() - self.assertFalse(used_gb_before + 0.1 < used_gb_after) + expect_corruption = (not TEST_CUDAMALLOCASYNC) and ( + share_mem != "Don't share" + ) + # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f. + # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". + self.assertEqual( + e.sum().item(), size * (7 + 3) if expect_corruption else size * 5 + ) + self.assertEqual(f.sum().item(), size * 7) + + del a, b, d, e, f, g0, g1, g2 + # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC, + "CUDA >= 11.0 or ROCM >= 5.3 required for graphs", ) - def test_graph_rng_functional(self): - ops_with_kwargs = ( - (torch.nn.functional.dropout, {"p": 0.1}), - (torch.nn.functional.rrelu, {"training": True}), - ) - size = 10000 - - def run(op, kwargs): - a = torch.randn((size,), device="cuda", dtype=torch.float) + def test_graph_memory_stats_and_use_result_after_destroy_graph(self): + kSmallSize = 1048576 + kSmallBuffer = 2097152 + kLargeBuffer = 20971520 + kMinLargeAlloc = 10485760 + kRoundLarge = 2097152 - # Control - torch.cuda.manual_seed(5) - eager_out = a - for _ in range(6): - eager_out = op(eager_out, **kwargs) + elem = 4 - graph_in = a.clone() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - torch.cuda.manual_seed(5) + # this was annoying to write but stresses the expectations pretty rigorously + cases = ( + (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), + (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), + ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), + ( + (kMinLargeAlloc - 512) // elem, + 2, + 2 * kLargeBuffer, + kLargeBuffer, + "large_pool", + ), + ( + (kMinLargeAlloc + 512) // elem, + 3, + 3 + * ( + kRoundLarge + * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge) + ), + kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge), + "large_pool", + ), + ) - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - g.capture_begin() - graph_out = graph_in - for _ in range(2): - graph_out = op(graph_out, **kwargs) - g.capture_end() - torch.cuda.current_stream().wait_stream(stream) + stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.") - # Runs a graphed->eager->graphed sequence of RNG ops. - # replay() plays 2 invocations of the op, so the sequence has 6 - # invocations total, matching Control. - # replay() reads from graph_in and writes to graph_out. - g.replay() - out = op(graph_out, **kwargs) - out = op(out, **kwargs) - graph_in.copy_(out) - g.replay() + gc.collect() + torch.cuda.empty_cache() - # If replay() updated RNG state correctly, graph_out - # should now hold data equal to eager_out. - try: - self.assertEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + s = torch.cuda.Stream() - # Do the same operations varying seeds - seeds = [6, 128, 9999] + for ( + numel, + delta_cudaMallocs, + delta_cudaMalloc_bytes, + delta_cudaMalloc_bytes_post_del_g, + pool_string, + ) in cases: + if pool_string == "small_pool": + delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders + delta_active_bytes = ( + numel * elem + 1024 + ) # + 1024 for CUDAGraph's rng seed and offset holders each + else: + delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder + delta_active_bytes = numel * elem - for seed in seeds: - torch.cuda.manual_seed(seed) - graph_in.copy_(a) - for _ in range(3): - g.replay() + g = torch.cuda.CUDAGraph() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + # Allocation stat estimates assume input is created on the same stream as capture_begin() + # (in other words, the same stream silo as the rng offset holder, which is not allocated from the + # capture's private pool). + a = torch.ones((numel,), device="cuda") - # If the random seed was not updated then the graph would - # generate the same output as in previous check. - try: - self.assertNotEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + precapture_stats = torch.cuda.memory_stats() - # Now repeat the same operations in non-graphed mode. - torch.cuda.manual_seed(seed) - for _ in range(3): - eager_out.copy_(a) - eager_out = op(eager_out, **kwargs) - eager_out = op(eager_out, **kwargs) + g.capture_begin() + b = a.clone() + for _ in range(5): + b = b.clone() + 1 + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - # In the end, graph_out and eager_out must be equal - # as they went under the same set of operations. - try: - self.assertEqual(eager_out, graph_out) - except Exception as e: - raise RuntimeError("Failed on ", op) from e + gc.collect() - # We hold references to all tensors used across streams up til this sync, - # so no need to call record_stream on those tensors. - torch.cuda.synchronize() + postcapture_stats = torch.cuda.memory_stats() - for op, kwargs in ops_with_kwargs: - run(op, kwargs) + expecteds = ( + delta_cudaMallocs, + delta_cudaMalloc_bytes, + delta_active_blocks, + delta_active_bytes, + ) + # Double checks replay and stats before and after a call to empty_cache + for i in range(2): + for stat, expected in zip(stats_to_check, expecteds): + stat = stat + pool_string + ".current" + current = postcapture_stats[stat] - precapture_stats[stat] - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_rng_distributions(self): - size = 10000 - input = torch.rand((size,), device="cuda", dtype=torch.float) - alloc = torch.empty((size,), device="cuda", dtype=torch.float) - - # Torch ops to test with sample args (tuple) and kwargs (dict) - torch_with_args = ( - ("bernoulli", (input.clone(),), {}), - # multinomial uses some uncapturable CUDA calls. - # TODO: reenable multinomial tests if/when the implementation is capturable. - # ("multinomial", (input.clone(), size, True), {}), - # ("multinomial", (input.clone(), size // 2, False), {}), - # TODO: reenable normal test, where std is a device - # tensor, when graph test failures are fixed - # ("normal", (input.clone() + 1, input.clone()), {}), - ("normal", (input.clone() + 1, 1.0), {}), - ("poisson", (input.clone(),), {}), - ("rand", (size,), {"device": "cuda", "dtype": torch.float}), - ("randint", (0, 3, (size,)), {"device": "cuda", "dtype": torch.float}), - ("randn", (size,), {"device": "cuda", "dtype": torch.float}), - ) + # There will only ever be one expandable segment in each of the small and large pools. The way the + # bookeeping is done in the allocator means that we never increment the number of segments. + if self.expandable_segments and "segment" in stat: + expected = 0 + # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an + # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is + # smaller than the page size + if ( + self.expandable_segments + and "reserved" in stat + and (numel == cases[3][0] or numel == cases[4][0]) + ): + expected = 2 * kLargeBuffer - # Tensor methods to test with sample args (tuple) - tensor_with_args = ( - ("bernoulli_", (input.clone(),)), - ("cauchy_", ()), - ("exponential_", ()), - ("geometric_", (0.3,)), - ("log_normal_", ()), - ("normal_", ()), - ("random_", ()), - ("uniform_", ()), - ) + self.assertEqual( + current, + expected, + "Pre to post capture delta of " + + stat + + f" = {current}, expected = {expected}, numel = {numel}", + ) - def run(module, op, args, kwargs): - torch.cuda.manual_seed(5) + g.replay() + self.assertEqual(b.sum().item(), 6 * numel) + if i == 0: + torch.cuda.empty_cache() - # Each path runs a dummy op to increment the state a bit before creating controls. - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - control1 = getattr(torch, op)(*args, **kwargs) - control2 = getattr(torch, op)(*args, **kwargs) - else: - dummy = alloc.clone() - control1 = alloc.clone() - control2 = alloc.clone() - getattr(dummy, op)(*args) - getattr(control1, op)(*args) - getattr(control2, op)(*args) + del g + gc.collect() + torch.cuda.empty_cache() + postdel_stats = torch.cuda.memory_stats() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - torch.cuda.manual_seed(5) + # Uses graph result b after graph has been deleted + self.assertEqual(b.sum().item(), 6 * numel) - g = torch.cuda.CUDAGraph() - torch.cuda.empty_cache() - if module == "torch": - g.capture_begin() - t1 = getattr(torch, op)(*args, **kwargs) - t2 = getattr(torch, op)(*args, **kwargs) - g.capture_end() - else: - t1 = alloc.clone() - t2 = alloc.clone() - g.capture_begin() - getattr(t1, op)(*args) - getattr(t2, op)(*args) - g.capture_end() - torch.cuda.current_stream().wait_stream(stream) + # b should be the only live reference remaining from the graph's private pool + expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem) + for stat, expected in zip(stats_to_check, expecteds): + stat = stat + pool_string + ".current" + current = postdel_stats[stat] - precapture_stats[stat] - if not TEST_CUDAMALLOCASYNC: - # Makes sure values haven't been populated yet - # (in other words, makes sure capture didn't actually run ops). - # We can only try this with the native allocator, for which captured - # addresses are already backed by cudaMalloced memory. - # If we try it with cudaMallocAsync, CUDA won't event consider - # the captured addresses allocated until replay(), and if we - # access them before replay() we get IMAs. - try: - self.assertNotEqual(control1, t1) - self.assertNotEqual(control2, t2) - except Exception as e: - raise RuntimeError("Failed on " + module + "." + op) from e + # There will only ever be one expandable segment in each of the small and large pools. The way the + # bookeeping is done in the allocator means that we never increment the number of segments. + if self.expandable_segments and "segment" in stat: + expected = 0 + # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an + # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is + # smaller than the page size + if ( + self.expandable_segments + and "reserved" in stat + and numel == cases[3][0] + ): + expected = 2 * kLargeBuffer + if ( + self.expandable_segments + and "reserved" in stat + and numel == cases[4][0] + ): + expected = kLargeBuffer - # Set a new seed to check if graph would use it - for seed in [6, 314, 271]: - torch.cuda.manual_seed(seed) - # Runs a dummy op prelude, as for controls, to make sure replay() - # picks up the dummy op's state increment. - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - control1 = getattr(torch, op)(*args, **kwargs) - control2 = getattr(torch, op)(*args, **kwargs) - else: - getattr(dummy, op)(*args) - getattr(control1, op)(*args) - getattr(control2, op)(*args) + self.assertEqual( + current, + expected, + "Pre capture to post graph delete delta of " + + stat + + f" = {current}, expected = {expected}, numel = {numel}", + ) - torch.cuda.manual_seed(seed) - if module == "torch": - dummy = getattr(torch, op)(*args, **kwargs) - else: - getattr(dummy, op)(*args) + # del a, b before the next case is essential, otherwise overwriting a and b in the next case + # can throw off its allocation/deallocation counts. + del a, b + # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. + torch.cuda.synchronize() + torch.cuda.empty_cache() - # see above comment on TEST_CUDAMALLOCASYNC - if not TEST_CUDAMALLOCASYNC: - t1.copy_(alloc) - t2.copy_(alloc) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_record_stream(self): + # Makes sure graph capture defers attempting to reclaim allocations used across streams. See + # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp + torch.cuda.empty_cache() - # Runs RNG ops that fill t1 and t2. - g.replay() + potential_problem = torch.zeros((3,), device="cuda") + a = torch.zeros((3,), device="cuda") + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + g = torch.cuda.CUDAGraph() - try: - self.assertEqual(control1, t1) - self.assertEqual(control2, t2) - except Exception as e: - raise RuntimeError("Failed on " + module + "." + op) from e + torch.cuda.synchronize() + with torch.cuda.stream(s0): + potential_problem.record_stream(s0) + torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES) + potential_problem.fill_(1.0) + del potential_problem - # We hold references to all tensors used across streams up til this sync, - # so no need to call record_stream on those tensors. - torch.cuda.synchronize() + with torch.cuda.stream(s1): + g.capture_begin() + # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc + # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life + # event, which will cause the capture to error. + b = a.clone() - for op_with_args in torch_with_args: - run("torch", *op_with_args) + # Let's also see what happens if we record_stream on a tensor during capture. + s2.wait_stream(s1) + with torch.cuda.stream(s2): + b.fill_(1.0) + b.record_stream(s2) # dummy record_stream + del b + s1.wait_stream(s2) + g.capture_end() + torch.cuda.synchronize() - for meth_with_args in tensor_with_args: - # Adds an empty dict for kwargs, which none of the Tensor methods use - run("Tensor", *(meth_with_args + ({},))) + # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event. + c = torch.zeros((3,), device="cuda") + @skipIfRocm @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_two_successive(self): + # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize + # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior + # as a memory leak unless we skip the leak check. + @skipCUDAMemoryLeakCheckIf(True) + @serialTest() + def test_graph_cudnn_dropout(self): + # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp. + # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should + # avoid syncing noncapturing streams with captured events or vice versa. torch.cuda.empty_cache() - size = 1000 - kSmallBuffer = 2097152 + model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda() + x = torch.ones(100, 192, 512, device="cuda") - def func_with_temps(t, val): - x = t.clone() + val - y = t.clone() + val - return x + y + y = model(x) + g = torch.cuda.CUDAGraph() s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + g.capture_begin() + y = model(x) + g.capture_end() + torch.cuda.current_stream().wait_stream(s) - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() + g.replay() - a = torch.ones((size,), device="cuda") + y = model(x) - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () - ) - g0.capture_begin(*g0_args) - b = a.clone() - for _ in range(5): - b = func_with_temps(b, 1) - g0.capture_end() + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @parametrize( + "with_amp,cache_enabled,allow_unused_input", + [ + subtest((False, False, True), decorators=[skipIfRocm]), + subtest((True, False, True), decorators=[skipIfRocm]), + subtest((True, True, True), decorators=[unittest.expectedFailure]), + subtest((False, False, False), decorators=[unittest.expectedFailure]), + ], + name_fn=lambda x, y, z: "{}{}{}".format( + {True: "with_amp", False: "without_amp"}[x], + {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", + {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], + ), + ) + @serialTest() + def test_graph_make_graphed_callables( + self, with_amp, cache_enabled, allow_unused_input + ): + torch.manual_seed(5) + torch.cuda.manual_seed(5) - g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args - g1.capture_begin(*g1_args) - for _ in range(5): - b = func_with_temps(b, 1) - g1.capture_end() - torch.cuda.current_stream().wait_stream(s) + N, D_in, H, D_out = 640, 4096, 2048, 1024 - # mixes unrelated eager ops with replays - c = a.clone() - for _ in range(2): - c = func_with_temps(c, 3) - g0.replay() - for _ in range(2): - c = func_with_temps(c, 3) - g1.replay() - for _ in range(2): - c = func_with_temps(c, 3) + class MLP1(torch.nn.Module): + def __init__(self, D_in: int, H: int, D_out: int): + super().__init__() + self.net_1 = torch.nn.Sequential( + torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) + ).cuda() + self.net_2 = torch.nn.Sequential( + torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) + ).cuda() - self.assertEqual(b.sum().item(), size * 3070) - self.assertEqual(c.sum().item(), size * 442) + def forward(self, input_dict: dict): + x = input_dict["x"] + return self.net_2(self.net_1(x)) - if not TEST_CUDAMALLOCASYNC: - # These stat checks are specific to the native allocator. - if share_mem != "Don't share": - self.assertEqual( - reserved_no_sharing # noqa: F821 - - torch.cuda.memory_stats()["reserved_bytes.all.current"], - kSmallBuffer, - ) - else: - reserved_no_sharing = torch.cuda.memory_stats()[ - "reserved_bytes.all.current" - ] + class MLP2(torch.nn.Module): + def __init__(self, D_in: int, H: int, D_out: int): + super().__init__() + self.net_1 = torch.nn.Sequential( + torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) + ).cuda() + self.net_2 = torch.nn.Sequential( + torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) + ).cuda() - del a, b, c, g0, g1 - # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + def forward(self, x): + return self.net_2(self.net_1(x)) - @unittest.skipIf( - (not TEST_CUDA_GRAPH) - or IS_WINDOWS - or ( # appears to still be broken on Windows as of 11.4+ - torch.version.cuda - and int(torch.version.cuda.split(".")[0]) == 11 - and int(torch.version.cuda.split(".")[1]) < 4 - ), - "Graph bindings disallow concurrent replay for CUDA < 11.4, see " - + "https://github.com/pytorch/pytorch/pull/57556", - ) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_concurrent_replay(self): - torch.cuda.empty_cache() + class ParameterlessModule(torch.nn.Module): + def forward(self, x): + idx = ( + torch.arange(x.size(0), device=x.device) + .view(-1, 1) + .repeat(1, x.size(1)) + ) + return {"output": torch.gather(x, 0, idx)} - size = 1000000 # largeish to help expose race conditions + models = [] + for _ in range(2): + model_section1 = MLP1(D_in, H, H).cuda() + model_section2 = MLP2(H, H, D_out).cuda() + model_section3 = ParameterlessModule().cuda() + models.append( + torch.nn.Sequential(model_section1, model_section2, model_section3) + ) - def func_with_temps(t, val): - x = t.clone() + val - y = t.clone() + val - return x + y + model_graphed = models[0] + model_control = models[1] - s = torch.cuda.Stream() + model_graphed.load_state_dict(model_control.state_dict()) - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() + opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) + opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() + x = torch.randn(N, D_in, device="cuda") + h = torch.randn(N, H, device="cuda", requires_grad=True) + h2 = torch.randn(N, D_out, device="cuda", requires_grad=True) + unused_input = torch.randn(N, H, device="cuda", requires_grad=True) + y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True) + y = torch.randn(N, D_out, device="cuda") - a = torch.ones((size,), device="cuda") + loss_fn_control = torch.nn.functional.mse_loss + relu_control = torch.nn.functional.relu - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () - ) - g0.capture_begin(*g0_args) - b = a.clone() - for _ in range(5): - b = func_with_temps(b, 1) - g0.capture_end() + # This is a good stress test. It graphs four callables: two Modules and two python functions. + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + ( + model_graphed[0], + model_graphed[1], + model_graphed[2], + relu_graphed, + loss_fn_graphed, + ) = torch.cuda.make_graphed_callables( + ( + model_graphed[0], + model_graphed[1], + model_graphed[2], + relu_control, + loss_fn_control, + ), + ( + ({"x": x, "unused_input": unused_input},), + (h,), + (h2,), + (y_pred,), + (y_pred, y), + ), + allow_unused_input=allow_unused_input, + ) - g1_args = (g0.pool(),) if share_mem == "via pool()" else g0_args - g1.capture_begin(*g1_args) - c = a.clone() - for _ in range(5): - c = func_with_temps(c, 2) - g1.capture_end() + real_inputs = [torch.rand_like(x) for _ in range(10)] + real_targets = [torch.rand_like(y) for _ in range(10)] - # To reproduce data corruption, I need g0 and g1's kernels to run concurrently. - # But replay() (especially cudaGraphLaunch) can incur significant CPU overhead. - # The following pattern helps align device-side execution of g0 and g1's kernels. - torch.cuda.synchronize() - with torch.cuda.stream(s0): - torch.cuda._sleep(1000000) - s1.wait_stream(s0) - g0.replay() - with torch.cuda.stream(s1): - g1.replay() - torch.cuda.current_stream().wait_stream(s0) - torch.cuda.current_stream().wait_stream(s1) + for m, opt, relu, loss_fn in zip( + (model_graphed, model_control), + (opt_graphed, opt_control), + (relu_graphed, relu_control), + (loss_fn_graphed, loss_fn_control), + ): + # Resets RNC states before iterations for graphed and ungraphed models, + # so dropout math should be bitwise identical for both. + torch.manual_seed(5) + torch.cuda.manual_seed(5) + for data, target in zip(real_inputs, real_targets): + opt.zero_grad(set_to_none=True) + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + y_pred = m({"x": data, "unused_input": unused_input})["output"] + y_pred = relu(y_pred) + loss = loss_fn(y_pred, target) + loss.backward() + opt.step() - if (not TEST_CUDAMALLOCASYNC) and (share_mem != "Don't share"): - # If we used the native allocator and shared mempools, - # we expect the concurrent replays corrupted each other. - self.assertNotEqual(b.sum().item(), size * 94) - self.assertNotEqual(c.sum().item(), size * 156) - else: - # If we EITHER - # - used the native allocator without sharing mempools, OR - # - used cudaMallocAsync, which ignores graph pool-sharing hints and should always be safe - # we don't expect memory corruption. - self.assertEqual(b.sum().item(), size * 94) - self.assertEqual(c.sum().item(), size * 156) + for p, pc in zip(model_graphed.parameters(), model_control.parameters()): + self.assertEqual(p, pc) - del a, b, c, g0, g1 - # Tensors used across streams (a, b, c) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + # We graphed the models in training mode. Eval should still run ungraphed. + model_graphed.eval() + model_control.eval() + self.assertEqual( + model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + ) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_three_successive(self): - torch.cuda.empty_cache() - - size = 1000 - - s = torch.cuda.Stream() - - for share_mem in ("Don't share", "via pool()", "via graph_pool_handle()"): - a = torch.ones((size,), device="cuda") - - g0 = torch.cuda.CUDAGraph() - g1 = torch.cuda.CUDAGraph() - g2 = torch.cuda.CUDAGraph() + @parametrize( + "with_amp,cache_enabled,allow_unused_input", + [ + subtest((False, False, True), decorators=[skipIfRocm]), + subtest((True, False, True), decorators=[skipIfRocm]), + subtest((True, True, True), decorators=[unittest.expectedFailure]), + subtest((False, False, False), decorators=[skipIfRocm]), + ], + name_fn=lambda x, y, z: "{}{}{}".format( + {True: "with_amp", False: "without_amp"}[x], + {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", + {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], + ), + ) + @serialTest() + def test_graph_make_graphed_callables_parameterless_nograd_module( + self, with_amp, cache_enabled, allow_unused_input + ): + torch.manual_seed(5) + torch.cuda.manual_seed(5) - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - g0_args = ( - (torch.cuda.graph_pool_handle(),) - if share_mem == "via graph_pool_handle()" - else () + N, D_in, H, D_out = 640, 4096, 2048, 1024 + + class ParameterlessModule(torch.nn.Module): + def forward(self, input_dict: dict): + x = input_dict["x"] + idx = ( + torch.arange(x.size(0), device=x.device) + .view(-1, 1) + .repeat(1, x.size(1)) ) - g0.capture_begin(*g0_args) - b = a.clone() - c = b + 1 - d = b + 2 - g0.capture_end() + return {"output": torch.gather(x, 0, idx)} - args = (g0.pool(),) if share_mem == "via pool()" else g0_args + models = [] + for _ in range(2): + model_section1 = ParameterlessModule().cuda() + models.append(torch.nn.Sequential(model_section1)) - g1.capture_begin(*args) - e = c + 3 - del c - g1.capture_end() + model_graphed = models[0] + model_control = models[1] - g2.capture_begin(*args) - f = d + 4 - g2.capture_end() - torch.cuda.current_stream().wait_stream(s) + model_graphed.load_state_dict(model_control.state_dict()) - # Tests that replaying in capture order is valid - g0.replay() - g1.replay() - g2.replay() + x = torch.randn(N, D_in, device="cuda", requires_grad=False) + unused_input = torch.randn(N, H, device="cuda", requires_grad=False) + y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False) + y = torch.randn(N, D_in, device="cuda") - self.assertEqual(e.sum().item(), size * 5) - self.assertEqual(f.sum().item(), size * 7) + # This is a good stress test. It graphs four callables: two Modules and two python functions. + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + model_graphed[0] = torch.cuda.make_graphed_callables( + model_graphed[0], + ({"x": x, "unused_input": unused_input},), + allow_unused_input=allow_unused_input, + ) - # Tests that replaying as g0, g2, g1 is only valid if they don't share a pool - g0.replay() - g2.replay() - g1.replay() + real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)] + real_targets = [torch.rand_like(y) for _ in range(10)] - expect_corruption = (not TEST_CUDAMALLOCASYNC) and ( - share_mem != "Don't share" - ) - # If we used the native allocator and shared mempools, g2's capture should have reused c's memory for f. - # We replayed g2 then g1, so we expect g1's captured "e = c + 3" mistakenly filled e with "f's vals + 3". - self.assertEqual( - e.sum().item(), size * (7 + 3) if expect_corruption else size * 5 - ) - self.assertEqual(f.sum().item(), size * 7) + for m in (model_graphed, model_control): + # Resets RNC states before iterations for graphed and ungraphed models, + # so dropout math should be bitwise identical for both. + torch.manual_seed(5) + torch.cuda.manual_seed(5) + for data, _ in zip(real_inputs, real_targets): + with torch.amp.autocast( + device_type="cuda", enabled=with_amp, cache_enabled=cache_enabled + ): + out = m({"x": data, "unused_input": unused_input})["output"] - del a, b, d, e, f, g0, g1, g2 - # Tensors used across streams (a, e, f) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + # We graphed the models in training mode. Eval should still run ungraphed. + model_graphed.eval() + model_control.eval() + self.assertEqual( + model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + ) @unittest.skipIf( - (not TEST_CUDA_GRAPH) or TEST_CUDAMALLOCASYNC, - "CUDA >= 11.0 or ROCM >= 5.3 required for graphs", + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_memory_stats_and_use_result_after_destroy_graph(self): - kSmallSize = 1048576 - kSmallBuffer = 2097152 - kLargeBuffer = 20971520 - kMinLargeAlloc = 10485760 - kRoundLarge = 2097152 - - elem = 4 + def test_graph_make_graphed_callables_same_pool(self): + torch.manual_seed(5) + torch.cuda.manual_seed(5) + models = [] + num_models = 3 + for _ in range(num_models): + models.append( + torch.nn.Sequential( + torch.nn.Linear(32, 128), + torch.nn.ReLU(), + torch.nn.Linear(128, 128), + ).cuda() + ) + # we will reuse the same pool for all graph captures + mempool = torch.cuda.graph_pool_handle() + graphed_models = [] + for model in models: + x = torch.randn([64, 32], device="cuda") + graphed_model = deepcopy(model) + graphed_model = torch.cuda.make_graphed_callables( + graphed_model, (x,), pool=mempool + ) + graphed_models.append(graphed_model) - # this was annoying to write but stresses the expectations pretty rigorously - cases = ( - (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), - (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), - ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), - ( - (kMinLargeAlloc - 512) // elem, - 2, - 2 * kLargeBuffer, - kLargeBuffer, - "large_pool", - ), - ( - (kMinLargeAlloc + 512) // elem, - 3, - 3 - * ( - kRoundLarge - * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge) - ), - kRoundLarge * ((kMinLargeAlloc + 512 + kRoundLarge - 1) // kRoundLarge), - "large_pool", - ), - ) + for model, graphed_model in zip(models, graphed_models): + x = torch.randn([64, 32], device="cuda") + y = model(x) + yg = graphed_model(x) + l = y.norm() + lg = yg.norm() + l.backward() + lg.backward() - stats_to_check = ("segment.", "reserved_bytes.", "active.", "active_bytes.") + self.assertEqual(y, yg) + self.assertEqual(l, lg) + for p, pg in zip(model.parameters(), graphed_model.parameters()): + self.assertEqual(p, pg) + self.assertEqual(p.grad, pg.grad) + self.assertNotEqual(p.data_ptr(), pg.data_ptr()) + self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr()) - gc.collect() - torch.cuda.empty_cache() + def _test_graphed_optimizer( + self, steps_warmup, steps_train, optimizer_ctor, kwargs + ): + for actually_do_graphs in (True, False): + params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [ + torch.randn((), device="cuda") + ] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] - s = torch.cuda.Stream() + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] - for ( - numel, - delta_cudaMallocs, - delta_cudaMalloc_bytes, - delta_cudaMalloc_bytes_post_del_g, - pool_string, - ) in cases: - if pool_string == "small_pool": - delta_active_blocks = 3 # one from "b" plus a sneaky two from CUDAGraph's one-element rng seed and offset holders - delta_active_bytes = ( - numel * elem + 1024 - ) # + 1024 for CUDAGraph's rng seed and offset holders each - else: - delta_active_blocks = 1 # We only check the large pool, which isn't affected by rng offset holder - delta_active_bytes = numel * elem + # Control (capturable=False) - g = torch.cuda.CUDAGraph() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - # Allocation stat estimates assume input is created on the same stream as capture_begin() - # (in other words, the same stream silo as the rng offset holder, which is not allocated from the - # capture's private pool). - a = torch.ones((numel,), device="cuda") + opt = optimizer_ctor(params_control, capturable=False, **kwargs) - precapture_stats = torch.cuda.memory_stats() + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() - g.capture_begin() - b = a.clone() - for _ in range(5): - b = b.clone() + 1 - g.capture_end() - torch.cuda.current_stream().wait_stream(s) + # capturable=True - gc.collect() + opt = optimizer_ctor(params_graphed, capturable=True, **kwargs) - postcapture_stats = torch.cuda.memory_stats() + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() - expecteds = ( - delta_cudaMallocs, - delta_cudaMalloc_bytes, - delta_active_blocks, - delta_active_bytes, - ) - # Double checks replay and stats before and after a call to empty_cache - for i in range(2): - for stat, expected in zip(stats_to_check, expecteds): - stat = stat + pool_string + ".current" - current = postcapture_stats[stat] - precapture_stats[stat] + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() - # There will only ever be one expandable segment in each of the small and large pools. The way the - # bookeeping is done in the allocator means that we never increment the number of segments. - if self.expandable_segments and "segment" in stat: - expected = 0 - # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an - # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is - # smaller than the page size - if ( - self.expandable_segments - and "reserved" in stat - and (numel == cases[3][0] or numel == cases[4][0]) - ): - expected = 2 * kLargeBuffer - - self.assertEqual( - current, - expected, - "Pre to post capture delta of " - + stat - + f" = {current}, expected = {expected}, numel = {numel}", - ) - - g.replay() - self.assertEqual(b.sum().item(), 6 * numel) - if i == 0: - torch.cuda.empty_cache() + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() - del g - gc.collect() - torch.cuda.empty_cache() - postdel_stats = torch.cuda.memory_stats() + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) - # Uses graph result b after graph has been deleted - self.assertEqual(b.sum().item(), 6 * numel) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_optims_with_explicitly_capturable_param_groups(self): + # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__ + n_warmup, n_replay = 3, 2 + for optimizer, second_param_group_capturable in product( + ( + torch.optim.Adam, + torch.optim.AdamW, + torch.optim.ASGD, + torch.optim.Adamax, + torch.optim.NAdam, + torch.optim.RAdam, + torch.optim.Adadelta, + torch.optim.RMSprop, + torch.optim.Rprop, + ), + (True, False), + ): + ref_p1, param1 = ( + torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) + ) + ref_p2, param2 = ( + torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) + ) + grads1, grads2 = ( + [torch.randn_like(param1) for _ in range(n_warmup + n_replay)] + for _ in range(2) + ) + ref_grads1, ref_grads2 = ( + [t.clone() for t in tensors] for tensors in (grads1, grads2) + ) + params = [ + {"params": [param1], "capturable": True}, + {"params": [param2], "capturable": second_param_group_capturable}, + ] + opt = optimizer(params) + opt_ = optimizer( + [ + {"params": [ref_p1], "capturable": False}, + {"params": [ref_p2], "capturable": False}, + ] + ) - # b should be the only live reference remaining from the graph's private pool - expecteds = (1, delta_cudaMalloc_bytes_post_del_g, 1, numel * elem) - for stat, expected in zip(stats_to_check, expecteds): - stat = stat + pool_string + ".current" - current = postdel_stats[stat] - precapture_stats[stat] + for i in range(n_warmup + n_replay): + ref_p1.grad = ref_grads1[i] + ref_p2.grad = ref_grads2[i] + opt_.step() - # There will only ever be one expandable segment in each of the small and large pools. The way the - # bookeeping is done in the allocator means that we never increment the number of segments. - if self.expandable_segments and "segment" in stat: - expected = 0 - # These two cases hit an edge case where the PyTorch allocator won't immediately unmap part of an - # expandable segment (and as a result reduce the number of reserved bytes) if the block to unmap is - # smaller than the page size - if ( - self.expandable_segments - and "reserved" in stat - and numel == cases[3][0] - ): - expected = 2 * kLargeBuffer - if ( - self.expandable_segments - and "reserved" in stat - and numel == cases[4][0] - ): - expected = kLargeBuffer + for i in range(n_warmup): + param1.grad = grads1[i] + param2.grad = grads2[i] + opt.step() - self.assertEqual( - current, - expected, - "Pre capture to post graph delete delta of " - + stat - + f" = {current}, expected = {expected}, numel = {numel}", - ) + g = torch.cuda.CUDAGraph() + if not second_param_group_capturable: + with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"): + with torch.cuda.graph(g): + opt.step() + else: + with torch.cuda.graph(g): + opt.step() - # del a, b before the next case is essential, otherwise overwriting a and b in the next case - # can throw off its allocation/deallocation counts. - del a, b - # Tensors used across streams (a and b) were held until just now, so no need to call record_stream on them. - torch.cuda.synchronize() - torch.cuda.empty_cache() + for i in range(n_replay): + param1.grad.copy_(grads1[n_warmup + i]) + param2.grad.copy_(grads2[n_warmup + i]) + g.replay() + self.assertEqual(ref_p1, param1) + self.assertEqual(ref_p2, param2) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - def test_graph_record_stream(self): - # Makes sure graph capture defers attempting to reclaim allocations used across streams. See - # "Q. Why skip process_events if a capture might be underway?" in c10/cuda/CUDACachingAllocator.cpp - torch.cuda.empty_cache() + def test_cuda_graph_error_options(self): + def fn(): + x = torch.zeros([2000], device="cuda") + y = x + x + x + return y - potential_problem = torch.zeros((3,), device="cuda") - a = torch.zeros((3,), device="cuda") - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s2 = torch.cuda.Stream() - g = torch.cuda.CUDAGraph() + mem = None - torch.cuda.synchronize() - with torch.cuda.stream(s0): - potential_problem.record_stream(s0) - torch.cuda._sleep(TestCuda.FIFTY_MIL_CYCLES) - potential_problem.fill_(1.0) - del potential_problem + def raw_malloc(): + global mem + mem = None + stream = torch.cuda.Stream() + try: + with torch.cuda.stream(stream): + mem = torch.cuda.caching_allocator_alloc(1024) + except BaseException: + if mem is None: + return + try: + torch.cuda.caching_allocator_delete(mem) + mem = None + return None + except BaseException: + pass - with torch.cuda.stream(s1): - g.capture_begin() - # potential_problem's allocation should still be outstanding. if DeviceCachingAllocator::malloc - # mistakenly calls process_events, it will trigger cudaEventQueries on potential_problem's end-of-life - # event, which will cause the capture to error. - b = a.clone() + def throws_on_cuda_event(capture_error_mode): + graph = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + fn() + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + try: + with torch.cuda.graph( + graph, stream=stream, capture_error_mode=capture_error_mode + ): + out = fn() + thread = threading.Thread(target=raw_malloc) + thread.start() + thread.join() + except Exception: + if mem is not None: + torch.cuda.caching_allocator_delete(mem) + return True - # Let's also see what happens if we record_stream on a tensor during capture. - s2.wait_stream(s1) - with torch.cuda.stream(s2): - b.fill_(1.0) - b.record_stream(s2) # dummy record_stream - del b - s1.wait_stream(s2) - g.capture_end() - torch.cuda.synchronize() + return False - # dummy allocation triggers process_events, Hopefully successfully processes b's end-of-life event. - c = torch.zeros((3,), device="cuda") + self.assertFalse(throws_on_cuda_event("thread_local")) + self.assertFalse(throws_on_cuda_event("relaxed")) + + # Exception would Corrupt Process and make other tests fail + # self.assertTrue(throws_on_cuda_event("global")) - @skipIfRocm @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) - # If this test is the first in the process to try cudnn rnns with dropout, it'll initialize - # DropoutState's long-lived internal buffer. Calling code perceives this (correct) behavior - # as a memory leak unless we skip the leak check. - @skipCUDAMemoryLeakCheckIf(True) - @serialTest() - def test_graph_cudnn_dropout(self): - # Tests the interaction of cuda graph capture with DropoutState's syncs in ATen/native/cudnn/RNN.cpp. - # In particular, if user runs a sequence of captured and noncaptured cudnn rnns, DropoutState should - # avoid syncing noncapturing streams with captured events or vice versa. - torch.cuda.empty_cache() - - model = torch.nn.LSTM(512, 512, 2, dropout=0.5).cuda() - x = torch.ones(100, 192, 512, device="cuda") - - y = model(x) - + def test_cuda_graph_allocator_propagates_stream(self): + segments = torch.cuda.memory_snapshot() + existing_pools = {s["segment_pool_id"] for s in segments} + x = torch.randn(10240000, device="cuda") + y = torch.rand_like(x) g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + s0.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s0): g.capture_begin() - y = model(x) + z = x + y + with torch.cuda.stream(s1): + s1.wait_stream(s0) + w = z + y + s0.wait_stream(s1) + with torch.cuda.stream(s0): g.capture_end() - torch.cuda.current_stream().wait_stream(s) + segments = torch.cuda.memory_snapshot() + x = [ + s["segment_pool_id"] + for s in segments + if s["segment_pool_id"] not in existing_pools + ] + self.assertEqual(len(x), 2) + self.assertEqual(x[0], x[1]) - g.replay() + def test_batch_norm_gather_stats(self): + input = torch.randn(1, 3, 3, 3, device="cuda") + mean, invstd = torch.batch_norm_gather_stats( + input, + mean=torch.ones(2, 3, device="cuda"), + invstd=torch.ones(2, 3, device="cuda"), + running_mean=None, + running_var=None, + momentum=0.1, + eps=1e-5, + count=2, + ) + self.assertEqual(mean, torch.ones(3, device="cuda")) + self.assertEqual(invstd, torch.ones(3, device="cuda")) - y = model(x) + def test_matmul_memory_use(self): + def get_max_used(): + torch.cuda.synchronize() + val = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + return val - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @parametrize( - "with_amp,cache_enabled,allow_unused_input", - [ - subtest((False, False, True), decorators=[skipIfRocm]), - subtest((True, False, True), decorators=[skipIfRocm]), - subtest((True, True, True), decorators=[unittest.expectedFailure]), - subtest((False, False, False), decorators=[unittest.expectedFailure]), - ], - name_fn=lambda x, y, z: "{}{}{}".format( - {True: "with_amp", False: "without_amp"}[x], - {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", - {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], - ), - ) - @serialTest() - def test_graph_make_graphed_callables( - self, with_amp, cache_enabled, allow_unused_input - ): - torch.manual_seed(5) - torch.cuda.manual_seed(5) + a = torch.rand(1, 32, 32, device="cuda") + b = torch.rand(24, 32, 1, device="cuda") - N, D_in, H, D_out = 640, 4096, 2048, 1024 + get_max_used() - class MLP1(torch.nn.Module): - def __init__(self, D_in: int, H: int, D_out: int): - super().__init__() - self.net_1 = torch.nn.Sequential( - torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) - ).cuda() - self.net_2 = torch.nn.Sequential( - torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) - ).cuda() + torch.matmul(a, b) - def forward(self, input_dict: dict): - x = input_dict["x"] - return self.net_2(self.net_1(x)) + matmul_mem = get_max_used() - class MLP2(torch.nn.Module): - def __init__(self, D_in: int, H: int, D_out: int): - super().__init__() - self.net_1 = torch.nn.Sequential( - torch.nn.Linear(D_in, H), torch.nn.Dropout(p=0.1) - ).cuda() - self.net_2 = torch.nn.Sequential( - torch.nn.Linear(H, D_out), torch.nn.Dropout(p=0.2) - ).cuda() + a = a.expand(24, 32, 32) + torch.matmul(a, b) - def forward(self, x): - return self.net_2(self.net_1(x)) + matmul_expand_mem = get_max_used() - class ParameterlessModule(torch.nn.Module): - def forward(self, x): - idx = ( - torch.arange(x.size(0), device=x.device) - .view(-1, 1) - .repeat(1, x.size(1)) - ) - return {"output": torch.gather(x, 0, idx)} + torch.bmm(a, b) - models = [] - for _ in range(2): - model_section1 = MLP1(D_in, H, H).cuda() - model_section2 = MLP2(H, H, D_out).cuda() - model_section3 = ParameterlessModule().cuda() - models.append( - torch.nn.Sequential(model_section1, model_section2, model_section3) - ) + bmm_mem = get_max_used() - model_graphed = models[0] - model_control = models[1] + self.assertEqual(matmul_expand_mem, matmul_mem) + self.assertEqual(bmm_mem, matmul_mem) - model_graphed.load_state_dict(model_control.state_dict()) + @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test") + def test_rocm_backward_pass_guard(self): + # The test exercises a ROCm-specific feature. - opt_graphed = torch.optim.SGD(model_graphed.parameters(), lr=0.1) - opt_control = torch.optim.SGD(model_control.parameters(), lr=0.1) + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, constant): + self.assertFalse(torch._C._rocm_is_backward_pass()) + ctx.constant = constant + return tensor * constant - x = torch.randn(N, D_in, device="cuda") - h = torch.randn(N, H, device="cuda", requires_grad=True) - h2 = torch.randn(N, D_out, device="cuda", requires_grad=True) - unused_input = torch.randn(N, H, device="cuda", requires_grad=True) - y_pred = torch.randn(N, D_out, device="cuda", requires_grad=True) - y = torch.randn(N, D_out, device="cuda") + @staticmethod + def backward(ctx, grad_output): + self.assertTrue(torch._C._rocm_is_backward_pass()) + return grad_output * ctx.constant, None - loss_fn_control = torch.nn.functional.mse_loss - relu_control = torch.nn.functional.relu + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) - # This is a good stress test. It graphs four callables: two Modules and two python functions. - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - ( - model_graphed[0], - model_graphed[1], - model_graphed[2], - relu_graphed, - loss_fn_graphed, - ) = torch.cuda.make_graphed_callables( - ( - model_graphed[0], - model_graphed[1], - model_graphed[2], - relu_control, - loss_fn_control, - ), - ( - ({"x": x, "unused_input": unused_input},), - (h,), - (h2,), - (y_pred,), - (y_pred, y), - ), - allow_unused_input=allow_unused_input, - ) + def forward(self, x): + return MyFunction.apply(x, self.a) - real_inputs = [torch.rand_like(x) for _ in range(10)] - real_targets = [torch.rand_like(y) for _ in range(10)] + model = MyModule() + criterion = torch.nn.MSELoss(reduction="sum") + optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - for m, opt, relu, loss_fn in zip( - (model_graphed, model_control), - (opt_graphed, opt_control), - (relu_graphed, relu_control), - (loss_fn_graphed, loss_fn_control), + x = torch.randn(5, 5) + result = model(x) + loss = criterion(result, x) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + def test_matmul_device_mismatch(self): + cpu = torch.rand((10, 10)) + cuda = cpu.cuda() + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" ): - # Resets RNC states before iterations for graphed and ungraphed models, - # so dropout math should be bitwise identical for both. - torch.manual_seed(5) - torch.cuda.manual_seed(5) - for data, target in zip(real_inputs, real_targets): - opt.zero_grad(set_to_none=True) - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - y_pred = m({"x": data, "unused_input": unused_input})["output"] - y_pred = relu(y_pred) - loss = loss_fn(y_pred, target) - loss.backward() - opt.step() + cpu @ cuda + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): + cuda @ cpu - for p, pc in zip(model_graphed.parameters(), model_control.parameters()): - self.assertEqual(p, pc) + for s, m1, m2 in product((cpu, cuda), repeat=3): + if s.device == m1.device == m2.device: + torch.addmm(s, m1, m2) + else: + with self.assertRaisesRegex( + RuntimeError, "Expected all tensors to be on the same device" + ): + torch.addmm(s, m1, m2) - # We graphed the models in training mode. Eval should still run ungraphed. - model_graphed.eval() - model_control.eval() - self.assertEqual( - model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) + @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") + def test_lazy_init(self): + """Validate that no CUDA calls are made during `import torch` call""" + + def check_output(script: str) -> str: + return ( + subprocess.check_output([sys.executable, "-c", script]) + .decode("ascii") + .strip() + ) + + VISIBLE_DEVICES = ( + "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES" ) + test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())" + rc = check_output(test_script) + self.assertEqual(rc, "0") + if not TEST_WITH_ROCM: + # Check that `cuInit` was not called during the import + # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3 + # See https://github.com/pytorch/pytorch/issues/116276 for more details + libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll" + cuda_driver_api_call = ( + f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))" + ) + rc = check_output( + f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})" + ) + self.assertEqual(rc, "3") - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - @parametrize( - "with_amp,cache_enabled,allow_unused_input", - [ - subtest((False, False, True), decorators=[skipIfRocm]), - subtest((True, False, True), decorators=[skipIfRocm]), - subtest((True, True, True), decorators=[unittest.expectedFailure]), - subtest((False, False, False), decorators=[skipIfRocm]), - ], - name_fn=lambda x, y, z: "{}{}{}".format( - {True: "with_amp", False: "without_amp"}[x], - {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else "", - {True: "_allow_unused_input", False: "_not_allow_unused_input"}[z], - ), - ) - @serialTest() - def test_graph_make_graphed_callables_parameterless_nograd_module( - self, with_amp, cache_enabled, allow_unused_input - ): - torch.manual_seed(5) - torch.cuda.manual_seed(5) - - N, D_in, H, D_out = 640, 4096, 2048, 1024 - - class ParameterlessModule(torch.nn.Module): - def forward(self, input_dict: dict): - x = input_dict["x"] - idx = ( - torch.arange(x.size(0), device=x.device) - .view(-1, 1) - .repeat(1, x.size(1)) - ) - return {"output": torch.gather(x, 0, idx)} - - models = [] - for _ in range(2): - model_section1 = ParameterlessModule().cuda() - models.append(torch.nn.Sequential(model_section1)) - - model_graphed = models[0] - model_control = models[1] + @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") + def test_hip_device_count(self): + """Validate device_count works with both CUDA/HIP visible devices""" + test_script = """\ +import torch +import os +print(f"{torch.cuda.device_count()}") +""" + custom_envs = [ + {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, + {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, + {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, + ] - model_graphed.load_state_dict(model_control.state_dict()) + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + if value is None: + env.pop(key, None) + else: + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("1", r) - x = torch.randn(N, D_in, device="cuda", requires_grad=False) - unused_input = torch.randn(N, H, device="cuda", requires_grad=False) - y_pred = torch.randn(N, D_in, device="cuda", requires_grad=False) - y = torch.randn(N, D_in, device="cuda") + @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices") + def test_device_count_not_cached_pre_init(self): + visible_devices = ( + "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES" + ) + test_script = f"""\ +import torch +import os +r1 = torch.cuda.device_count() +os.environ['{visible_devices}'] = '0' +r2 = torch.cuda.device_count() +torch.empty(10, device='cuda') +print(f"{{r1}}, {{r2}}") +""" - # This is a good stress test. It graphs four callables: two Modules and two python functions. - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - model_graphed[0] = torch.cuda.make_graphed_callables( - model_graphed[0], - ({"x": x, "unused_input": unused_input},), - allow_unused_input=allow_unused_input, - ) + r = ( + subprocess.check_output([sys.executable, "-c", test_script]) + .decode("ascii") + .strip() + ) - real_inputs = [torch.rand_like(x, requires_grad=True) for _ in range(10)] - real_targets = [torch.rand_like(y) for _ in range(10)] + x = torch.cuda.device_count() + self.assertEqual(f"{x}, 1", r) - for m in (model_graphed, model_control): - # Resets RNC states before iterations for graphed and ungraphed models, - # so dropout math should be bitwise identical for both. - torch.manual_seed(5) - torch.cuda.manual_seed(5) - for data, target in zip(real_inputs, real_targets): - with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled): - out = m({"x": data, "unused_input": unused_input})["output"] + @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") + def test_gds_fails_in_ci(self): + if IS_WINDOWS or TEST_WITH_ROCM: + error_msg = "is not supported on this platform" + else: + error_msg = "cuFileHandleRegister failed" + with TemporaryFileName() as f: + with self.assertRaisesRegex(RuntimeError, error_msg): + file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) - # We graphed the models in training mode. Eval should still run ungraphed. - model_graphed.eval() - model_control.eval() - self.assertEqual( - model_graphed({"x": real_inputs[0]}), model_control({"x": real_inputs[0]}) - ) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestCudaMallocAsync(TestCase): @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_graph_make_graphed_callables_same_pool(self): - torch.manual_seed(5) - torch.cuda.manual_seed(5) - models = [] - num_models = 3 - for _ in range(num_models): - models.append( - torch.nn.Sequential( - torch.nn.Linear(32, 128), - torch.nn.ReLU(), - torch.nn.Linear(128, 128), - ).cuda() - ) - # we will reuse the same pool for all graph captures - mempool = torch.cuda.graph_pool_handle() - graphed_models = [] - for model in models: - x = torch.randn([64, 32], device="cuda") - graphed_model = deepcopy(model) - graphed_model = torch.cuda.make_graphed_callables( - graphed_model, (x,), pool=mempool - ) - graphed_models.append(graphed_model) + def test_memory_snapshot(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("state", stacks="python") + # make x the second block in a segment + torch.rand(2 * 311, 411, device="cuda") + unused = torch.rand(310, 410, device="cuda") + x = torch.rand(311, 411, device="cuda") - for model, graphed_model in zip(models, graphed_models): - x = torch.randn([64, 32], device="cuda") - y = model(x) - yg = graphed_model(x) - l = y.norm() - lg = yg.norm() - l.backward() - lg.backward() + # create a bunch of tensors that all will tile into the + # same segment to exercise the history merging code + # 512B is the minimum block size, + # so we allocate all the tensors to this size to make sure + # they tile evenly + tensors = [torch.rand(128, device="cuda") for _ in range(1000)] + while tensors: + del tensors[randint(0, len(tensors) - 1)] - self.assertEqual(y, yg) - self.assertEqual(l, lg) - for p, pg in zip(model.parameters(), graphed_model.parameters()): - self.assertEqual(p, pg) - self.assertEqual(p.grad, pg.grad) - self.assertNotEqual(p.data_ptr(), pg.data_ptr()) - self.assertNotEqual(p.grad.data_ptr(), pg.grad.data_ptr()) + # exercise the history trimming code + torch.rand(128 * 5, device="cuda") - def _test_graphed_optimizer( - self, steps_warmup, steps_train, optimizer_ctor, kwargs - ): - for actually_do_graphs in (True, False): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] + [ - torch.randn((), device="cuda") - ] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] + ss = torch.cuda.memory._snapshot() + found_it = False + for seg in ss["segments"]: + self.assertTrue("frames" in seg) + for b in seg["blocks"]: + if b["requested_size"] == 311 * 411 * 4: + self.assertTrue("test_cuda" in b["frames"][0]["filename"]) + found_it = True + self.assertEqual(x.untyped_storage().data_ptr(), b["address"]) + self.assertTrue(found_it) - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] + if not IS_WINDOWS: + with tempfile.NamedTemporaryFile() as f: + torch.cuda.memory._save_segment_usage(f.name) + with open(f.name) as f2: + self.assertTrue("test_cuda.py" in f2.read()) + del unused + del x + torch.cuda.empty_cache() + ss = torch.cuda.memory._snapshot() + self.assertTrue( + ss["device_traces"][0][-1]["action"] + in ("segment_free", "segment_unmap") + ) - # Control (capturable=False) + finally: + torch.cuda.memory._record_memory_history(None) - opt = optimizer_ctor(params_control, capturable=False, **kwargs) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") + def test_direct_traceback(self): + from torch._C._profiler import gather_traceback, symbolize_tracebacks # @manual - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads[i][j] - opt.step() + c = gather_traceback(True, True, True) + (r,) = symbolize_tracebacks([c]) + r = str(r) + self.assertTrue("test_cuda.py" in r) + self.assertTrue("unwind" in r) - # capturable=True + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_snapshot_with_cpp(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("state", stacks="all") + x = torch.rand(311, 411, device="cuda") - opt = optimizer_ctor(params_graphed, capturable=True, **kwargs) + ss = torch.cuda.memory._snapshot()["segments"] + found_it = False + for seg in ss: + for b in seg["blocks"]: + if b["requested_size"] == 311 * 411 * 4: + self.assertTrue("::rand" in str(b["frames"])) + found_it = True + self.assertTrue(found_it) - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads[i][j] - opt.step() - - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - opt.step() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads[i + steps_warmup][j] - opt.step() + finally: + torch.cuda.memory._record_memory_history(None) - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) + @skipIfRocm + def test_memory_profiler_viz(self): + with torch.profiler.profile( + with_stack=True, profile_memory=True, record_shapes=True + ) as prof: + x = torch.rand(128, 128, device="cuda") + x * x + x * x + plot = profile_plot(prof) + plot = json.dumps(_profile_to_snapshot(prof)) + self.assertTrue("test_cuda.py" in plot) + self.assertTrue("test_memory_profiler_viz" in plot) + self.assertTrue("category" in plot) @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_graph_optims(self): - # Needs generalization if we want to extend this test to non-Adam-like optimizers. - cases = ( - [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "decoupled_weight_decay": decoupled_weight_decay, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product( - (torch.optim.NAdam, torch.optim.RAdam), - (False, True), - (False, True), - (0.0, 0.1), - ) - ] - + [ - ( - torch.optim.Rprop, - {"lr": 0.1, "foreach": foreach, "maximize": maximize}, - ) - for foreach, maximize in product( - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "betas": (0.8, 0.7), - "foreach": foreach, - "amsgrad": amsgrad, - }, - ) - for optimizer_ctor, foreach, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), - (False, True), - (False, True), - ) - ] - + [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] - + [ - ( - optimizer_ctor, - { - "lr": 0.1, - "foreach": foreach, - "maximize": maximize, - "weight_decay": weight_decay, - }, - ) - for optimizer_ctor, foreach, maximize, weight_decay in product( - ( - torch.optim.Adamax, - torch.optim.ASGD, - torch.optim.Adadelta, - torch.optim.RMSprop, - ), - (False, True), - (False, True), - (0, 0.1), - ) - ] - ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_cycles(self): + fired = False - for optimizer_ctor, kwargs in cases: - with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs): - self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs) + def observer(html): + nonlocal fired + fired = True + self.assertTrue("torch.Tensor" in html) + self.assertTrue("test_cuda" in html) + self.assertTrue("cell_contents" in html) - @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" - ) - def test_graph_optims_with_explicitly_capturable_param_groups(self): - # mimicking `_test_graphed_optimizer` maladroitly to pass two param_groups to optimizer.__init__ - n_warmup, n_replay = 3, 2 - for optimizer, second_param_group_capturable in product( - ( - torch.optim.Adam, - torch.optim.AdamW, - torch.optim.ASGD, - torch.optim.Adamax, - torch.optim.NAdam, - torch.optim.RAdam, - torch.optim.Adadelta, - torch.optim.RMSprop, - torch.optim.Rprop, - ), - (True, False), - ): - ref_p1, param1 = ( - torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) - ) - ref_p2, param2 = ( - torch.nn.Parameter(torch.ones(1, device="cuda")) for _ in range(2) - ) - grads1, grads2 = ( - [torch.randn_like(param1) for _ in range(n_warmup + n_replay)] - for _ in range(2) - ) - ref_grads1, ref_grads2 = ( - [t.clone() for t in tensors] for tensors in (grads1, grads2) - ) - params = [ - {"params": [param1], "capturable": True}, - {"params": [param2], "capturable": second_param_group_capturable}, - ] - opt = optimizer(params) - opt_ = optimizer( - [ - {"params": [ref_p1], "capturable": False}, - {"params": [ref_p2], "capturable": False}, - ] - ) + disarm = observe_tensor_cycles(observer) - for i in range(n_warmup + n_replay): - ref_p1.grad = ref_grads1[i] - ref_p2.grad = ref_grads2[i] - opt_.step() + def noop(): + pass - for i in range(n_warmup): - param1.grad = grads1[i] - param2.grad = grads2[i] - opt.step() + try: - g = torch.cuda.CUDAGraph() - if not second_param_group_capturable: - with self.assertRaisesRegex(RuntimeError, "Attempting CUDA graph"): - with torch.cuda.graph(g): - opt.step() - else: - with torch.cuda.graph(g): - opt.step() + def create(): + x = torch.empty(3, 4, device="cuda") - for i in range(n_replay): - param1.grad.copy_(grads1[n_warmup + i]) - param2.grad.copy_(grads2[n_warmup + i]) - g.replay() - self.assertEqual(ref_p1, param1) - self.assertEqual(ref_p2, param2) + def foo(p): + if p: + return foo(not p) + else: + return x + + return foo + + create() + gc.collect() + # the callback has to run outside of the collect + # call so it doesn't actual fire until the next + # method call after a gc.collect + noop() + self.assertTrue(fired) + finally: + disarm() @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_graph_scaling_fused_optimizers(self): - cases = [ - ( - optimizer_ctor, - {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}, - ) - for optimizer_ctor, amsgrad in product( - (torch.optim.Adam, torch.optim.AdamW), (False, True) - ) - ] + list( - product( - (torch.optim.SGD,), - [ - { - "lr": 0.1, - "momentum": 0.0, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) - ] - + [ - { - "lr": 0.1, - "momentum": 0.5, - "dampening": d, - "weight_decay": w, - "nesterov": n, - "fused": True, - } - for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) - ], - ) - ) - - steps_warmup = 3 - steps_train = 2 + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots(self): + for context, stacks in ( + ("all", "all" if IS_LINUX else "python"), + ("all", "python"), + (None, "python"), + ): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history( + "all", context=context, stacks=stacks + ) - for OptClass, kwargs in cases: - has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) - for actually_do_graphs in (True, False) if has_capturable_arg else (True,): - params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] - params_control = [p.clone().requires_grad_() for p in params] - params_graphed = [p.clone().requires_grad_() for p in params] + def run(): + x = torch.rand(128, 128, device="cuda") + x * x + x * x - # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. - grads = [ - [torch.randn_like(p) for p in params] - for _ in range(steps_warmup + steps_train) - ] - with torch.no_grad(): - grads_control = [[g.clone() for g in gs] for gs in grads] - grads_graphed = [[g.clone() for g in gs] for gs in grads] + run() + cpp = stacks == "all" + record_context = context is not None + ss = torch.cuda.memory._snapshot() - # Gradient Scaler - scaler_for_control = torch.amp.GradScaler( - device="cuda", init_scale=128.0 - ) - with torch.no_grad(): - scaler_for_control._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) + tplot = trace_plot(ss) + splot = segment_plot(ss) + text = json.dumps(ss) - scaler_for_graphed = torch.amp.GradScaler(device="cuda") - scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) - with torch.no_grad(): - scaler_for_graphed._lazy_init_scale_growth_tracker( - torch.device("cuda") - ) + self.assertTrue(record_context == ("test_memory_plots" in text)) + self.assertTrue(cpp == ("::rand" in text)) + self.assertTrue(str(128 * 128 * 4) in text) - # Control (capturable=False) - if has_capturable_arg: - kwargs["capturable"] = False - opt = OptClass(params_control, **kwargs) + finally: + torch.cuda.memory._record_memory_history(None) - for i in range(steps_warmup + steps_train): - for j, p in enumerate(params_control): - p.grad = grads_control[i][j] - scaler_for_control.step(opt) - scaler_for_control.update() + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_free_stack(self): + for context in ["alloc", "all", "state"]: + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history(context=context) + x = None - # capturable=True - if has_capturable_arg: - kwargs["capturable"] = True - opt = OptClass(params_graphed, **kwargs) + def thealloc(): + nonlocal x + x = torch.rand(3, 4, device="cuda") - for i in range(steps_warmup): - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() + def thefree(): + nonlocal x + del x - if actually_do_graphs: - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for i in range(steps_train): - if actually_do_graphs: - for j, p in enumerate(params_graphed): - p.grad.copy_(grads_graphed[i + steps_warmup][j]) - g.replay() - else: - # Passing capturable=True to the constructor and running without graphs should still be - # numerically correct, even if it's not ideal for performance. - for j, p in enumerate(params_graphed): - p.grad = grads_graphed[i + steps_warmup][j] - scaler_for_graphed.step(opt) - scaler_for_graphed.update() - - for p_control, p_graphed in zip(params_control, params_graphed): - self.assertEqual(p_control, p_graphed) + thealloc() + thefree() + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue(("thefree" in ss) == (context == "all")) + self.assertTrue(("thealloc" in ss) == (context != "state")) + finally: + torch.cuda.memory._record_memory_history(None) @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_cuda_graph_error_options(self): - def fn(): - x = torch.zeros([2000], device="cuda") - y = x + x + x - return y - - mem = None + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_history_context(self): + try: + torch.cuda.memory.empty_cache() + x = None - def raw_malloc(): - global mem - mem = None - stream = torch.cuda.Stream() - try: - with torch.cuda.stream(stream): - mem = torch.cuda.caching_allocator_alloc(1024) - except BaseException: - if mem is None: - return - try: - torch.cuda.caching_allocator_delete(mem) - mem = None - return None - except BaseException: - pass + def should_capture1(): + nonlocal x + x = torch.rand(4, 4, device="cuda") - def throws_on_cuda_event(capture_error_mode): - graph = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - fn() - stream.synchronize() - torch.cuda.current_stream().wait_stream(stream) - torch.cuda.synchronize() - try: - with torch.cuda.graph( - graph, stream=stream, capture_error_mode=capture_error_mode - ): - out = fn() - thread = threading.Thread(target=raw_malloc) - thread.start() - thread.join() - except Exception: - if mem is not None: - torch.cuda.caching_allocator_delete(mem) - return True + def should_not_capture(): + nonlocal x + x = torch.rand(3, 4, device="cuda") - return False + def should_capture2(): + nonlocal x + x = torch.rand(4, 4, device="cuda") - self.assertFalse(throws_on_cuda_event("thread_local")) - self.assertFalse(throws_on_cuda_event("relaxed")) + # Recording with context and python call stacks should capture the call stack. + torch.cuda.memory._record_memory_history(context="all", stacks="python") + should_capture1() + # Recording with context=None should not capture the call stack. + torch.cuda.memory._record_memory_history(context=None) + should_not_capture() + # Recording with context and python call stacks should capture the call stack. + torch.cuda.memory._record_memory_history(context="all", stacks="python") + should_capture2() - # Exception would Corrupt Process and make other tests fail - # self.assertTrue(throws_on_cuda_event("global")) + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue("should_capture1" in ss) + self.assertTrue("should_not_capture" not in ss) + self.assertTrue("should_capture2" in ss) + finally: + torch.cuda.memory._record_memory_history(None) @unittest.skipIf( - not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - def test_cuda_graph_allocator_propagates_stream(self): - segments = torch.cuda.memory_snapshot() - existing_pools = {s["segment_pool_id"] for s in segments} - x = torch.randn(10240000, device="cuda") - y = torch.rand_like(x) - g = torch.cuda.CUDAGraph() - s0 = torch.cuda.Stream() - s1 = torch.cuda.Stream() - s0.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s0): - g.capture_begin() - z = x + y - with torch.cuda.stream(s1): - s1.wait_stream(s0) - w = z + y - s0.wait_stream(s1) - with torch.cuda.stream(s0): - g.capture_end() - segments = torch.cuda.memory_snapshot() - x = [ - s["segment_pool_id"] - for s in segments - if s["segment_pool_id"] not in existing_pools - ] - self.assertEqual(len(x), 2) - self.assertEqual(x[0], x[1]) + @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") + def test_memory_plots_free_segment_stack(self): + for context in ["alloc", "all", "state"]: + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history(context=context) + x = torch.rand(3, 4, device="cuda") + del x + torch.cuda.memory.empty_cache() - def test_batch_norm_gather_stats(self): - input = torch.randn(1, 3, 3, 3, device="cuda") - mean, invstd = torch.batch_norm_gather_stats( - input, - mean=torch.ones(2, 3, device="cuda"), - invstd=torch.ones(2, 3, device="cuda"), - running_mean=None, - running_var=None, - momentum=0.1, - eps=1e-5, - count=2, - ) - self.assertEqual(mean, torch.ones(3, device="cuda")) - self.assertEqual(invstd, torch.ones(3, device="cuda")) + ss = json.dumps(torch.cuda.memory._snapshot()) + self.assertTrue(("empty_cache" in ss) == (context == "all")) + finally: + torch.cuda.memory._record_memory_history(None) - def test_matmul_memory_use(self): - def get_max_used(): - torch.cuda.synchronize() - val = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - return val + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + def test_memory_snapshot_script(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("state", stacks="python") - a = torch.rand(1, 32, 32, device="cuda") - b = torch.rand(24, 32, 1, device="cuda") + @torch.jit.script + def foo(): + return torch.rand(311, 411, device="cuda") - get_max_used() + x = foo() - torch.matmul(a, b) + ss = torch.cuda.memory._snapshot()["segments"] + found_it = False + for seg in ss: + for b in seg["blocks"]: + if b["requested_size"] == 311 * 411 * 4: + self.assertTrue(b["frames"][0]["name"] == "foo") + found_it = True + self.assertTrue(found_it) - matmul_mem = get_max_used() + finally: + torch.cuda.memory._record_memory_history(None) - a = a.expand(24, 32, 32) - torch.matmul(a, b) + def test_max_split_expandable(self): + torch.cuda.memory.empty_cache() + mb = 1024 * 1024 + _, all_memory = torch.cuda.memory.mem_get_info() + total_allowed = 120 * mb + fraction_allowed = total_allowed / all_memory + assert int(fraction_allowed * all_memory) == total_allowed + torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) - matmul_expand_mem = get_max_used() + def alloc(n): + return torch.ones(n * mb, dtype=torch.int8, device="cuda") - torch.bmm(a, b) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,max_split_size_mb:40" + ) + a = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:True,max_split_size_mb:40" + ) + b = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,max_split_size_mb:40" + ) + c = alloc(40) + with self.assertRaises(torch.OutOfMemoryError): + alloc(40) + del a, b, c + # force release_cached_blocks to run with some expandable segments in the free list + alloc(120) + + def test_garbage_collect_expandable(self): + torch.cuda.memory.empty_cache() + mb = 1024 * 1024 + _, all_memory = torch.cuda.memory.mem_get_info() + total_allowed = 120 * mb + fraction_allowed = total_allowed / all_memory + assert int(fraction_allowed * all_memory) == total_allowed + torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) - bmm_mem = get_max_used() + def alloc(n): + return torch.ones(n * mb, dtype=torch.int8, device="cuda") - self.assertEqual(matmul_expand_mem, matmul_mem) - self.assertEqual(bmm_mem, matmul_mem) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:False,garbage_collection_threshold:0.5" + ) + a = alloc(40) + torch.cuda.memory._set_allocator_settings( + "expandable_segments:True,garbage_collection_threshold:0.5" + ) + b = alloc(40) + del a, b + # causes GC to run. The expandable segment block will be split + # so GC would not attempt to free it anyway, but this at least makes sure + # expandable_segment blocks can be in the free list when this is called. + alloc(80) - @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-only test") - def test_rocm_backward_pass_guard(self): - # The test exercises a ROCm-specific feature. + def test_allocator_settings(self): + def power2_div(size, div_factor): + pow2 = 1 + while pow2 < size: + pow2 = pow2 * 2 + if pow2 == size: + return pow2 + step = pow2 / 2 / div_factor + ret = pow2 / 2 + while ret < size: + ret = ret + step + return ret - class MyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor, constant): - self.assertFalse(torch._C._rocm_is_backward_pass()) - ctx.constant = constant - return tensor * constant + torch.cuda.memory.empty_cache() + key_allocated = ( + "active_bytes.all.allocated" + if not TEST_CUDAMALLOCASYNC + else "allocated_bytes.all.current" + ) + key_requested = "requested_bytes.all.allocated" - @staticmethod - def backward(ctx, grad_output): - self.assertTrue(torch._C._rocm_is_backward_pass()) - return grad_output * ctx.constant, None + nelems = 21 * 1024 * 1024 + nbytes = 4 * nelems # floats are 4 bytes - class MyModule(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) + nelems_big = 100 * 1024 * 1024 + nbytes_big = 4 * nelems_big # floats are 4 bytes - def forward(self, x): - return MyFunction.apply(x, self.a) + start_mem = torch.cuda.memory_stats()[key_allocated] + torch.cuda.memory._set_allocator_settings("") + x = torch.rand(nelems, device="cuda") - model = MyModule() - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) + # test roundup_power2_divisions single value syntax + reg_mem = torch.cuda.memory_stats()[key_allocated] + start_requested = torch.cuda.memory_stats()[key_requested] + torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4") + y = torch.rand(nelems, device="cuda") - x = torch.randn(5, 5) - result = model(x) - loss = criterion(result, x) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - def test_matmul_device_mismatch(self): - cpu = torch.rand((10, 10)) - cuda = cpu.cuda() - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - cpu @ cuda - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - cuda @ cpu - - for s, m1, m2 in product((cpu, cuda), repeat=3): - if s.device == m1.device == m2.device: - torch.addmm(s, m1, m2) - else: - with self.assertRaisesRegex( - RuntimeError, "Expected all tensors to be on the same device" - ): - torch.addmm(s, m1, m2) - - @unittest.skipIf(TEST_MULTIGPU, "Testing on one GPU is sufficient") - def test_lazy_init(self): - """Validate that no CUDA calls are made during `import torch` call""" + pow2_div4_mem = torch.cuda.memory_stats()[key_allocated] + current_requested = torch.cuda.memory_stats()[key_requested] - def check_output(script: str) -> str: - return ( - subprocess.check_output([sys.executable, "-c", script]) - .decode("ascii") - .strip() - ) + self.assertTrue(reg_mem - start_mem == nbytes) + if not TEST_CUDAMALLOCASYNC: + # not supported with the cudaMallocAsync backend + self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4)) + self.assertTrue(current_requested - start_requested == nbytes) - VISIBLE_DEVICES = ( - "HIP_VISIBLE_DEVICES" if TEST_WITH_ROCM else "CUDA_VISIBLE_DEVICES" + torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5") + torch.cuda.memory._set_allocator_settings( + "garbage_collection_threshold:0.5,max_split_size_mb:40" ) - test_script = f"import os; import torch;os.environ['{VISIBLE_DEVICES}']='32';print(torch.cuda.device_count())" - rc = check_output(test_script) - self.assertEqual(rc, "0") - if not TEST_WITH_ROCM: - # Check that `cuInit` was not called during the import - # By using ctypes and calling cuDeviceCountGet() and expect CUDA_ERROR_NOT_INITIALIZED == 3 - # See https://github.com/pytorch/pytorch/issues/116276 for more details - libcuda_name = "libcuda.so.1" if not IS_WINDOWS else "nvcuda.dll" - cuda_driver_api_call = ( - f"ctypes.CDLL('{libcuda_name}').cuDeviceGetCount(ctypes.byref(x))" - ) - rc = check_output( - f"import torch; import ctypes;x=ctypes.c_int(-1);print({cuda_driver_api_call})" - ) - self.assertEqual(rc, "3") - @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") - def test_hip_device_count(self): - """Validate device_count works with both CUDA/HIP visible devices""" - test_script = """\ -import torch -import os -print(f"{torch.cuda.device_count()}") -""" - custom_envs = [ - {"CUDA_VISIBLE_DEVICES": "0", "HIP_VISIBLE_DEVICES": None}, - {"CUDA_VISIBLE_DEVICES": None, "HIP_VISIBLE_DEVICES": "0"}, - {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "HIP_VISIBLE_DEVICES": "0"}, - ] - - for env_config in custom_envs: - env = os.environ.copy() - for key, value in env_config.items(): - if value is None: - env.pop(key, None) - else: - env[key] = value - r = ( - subprocess.check_output([sys.executable, "-c", test_script], env=env) - .decode("ascii") - .strip() - ) - self.assertEqual("1", r) + # should have reset the power2 divisions now + torch.cuda.memory.empty_cache() + start_mem = torch.cuda.memory_stats()[key_allocated] + z = torch.rand(nelems, device="cuda") + reg_mem = torch.cuda.memory_stats()[key_allocated] + self.assertTrue(reg_mem - start_mem == nbytes) - @unittest.skipIf(not TEST_MULTIGPU, "requires multiple devices") - def test_device_count_not_cached_pre_init(self): - visible_devices = ( - "HIP_VISIBLE_DEVICES" if torch.version.hip else "CUDA_VISIBLE_DEVICES" + # roundup_power2_divisions knob array syntax + torch.cuda.memory.empty_cache() + torch.cuda.memory._set_allocator_settings( + "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]" ) - test_script = f"""\ -import torch -import os -r1 = torch.cuda.device_count() -os.environ['{visible_devices}'] = '0' -r2 = torch.cuda.device_count() -torch.empty(10, device='cuda') -print(f"{{r1}}, {{r2}}") -""" + start_mem = torch.cuda.memory_stats()[key_allocated] + w = torch.rand(nelems, device="cuda") - r = ( - subprocess.check_output([sys.executable, "-c", test_script]) - .decode("ascii") - .strip() - ) + pow2_div8_mem = torch.cuda.memory_stats()[key_allocated] + if not TEST_CUDAMALLOCASYNC: + # not supported with the cudaMallocAsync backend + self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8)) - x = torch.cuda.device_count() - self.assertEqual(f"{x}, 1", r) + torch.cuda.memory.empty_cache() + start_mem = torch.cuda.memory_stats()[key_allocated] + v = torch.rand(nelems_big, device="cuda") - @unittest.skip("Disabling as USE_CUFILE=0 by default in builds") - def test_gds_fails_in_ci(self): - if IS_WINDOWS or TEST_WITH_ROCM: - error_msg = "is not supported on this platform" - else: - error_msg = "cuFileHandleRegister failed" - with TemporaryFileName() as f: - with self.assertRaisesRegex(RuntimeError, error_msg): - file = torch.cuda.gds._GdsFile(f, os.O_CREAT | os.O_RDWR) + pow2_div2_mem = torch.cuda.memory_stats()[key_allocated] + if not TEST_CUDAMALLOCASYNC: + # not supported with the cudaMallocAsync backend + self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2)) + torch.cuda.memory.empty_cache() + torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True") + start_mem = torch.cuda.memory_stats()[key_allocated] + w = torch.rand(nelems, device="cuda") + reg_mem = torch.cuda.memory_stats()[key_allocated] + self.assertTrue(reg_mem - start_mem == nbytes) -@torch.testing._internal.common_utils.markDynamoStrictTest -class TestCudaMallocAsync(TestCase): - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - def test_memory_snapshot(self): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history("state", stacks="python") - # make x the second block in a segment - torch.rand(2 * 311, 411, device="cuda") - unused = torch.rand(310, 410, device="cuda") - x = torch.rand(311, 411, device="cuda") + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings("foo:1,bar:2") - # create a bunch of tensors that all will tile into the - # same segment to exercise the history merging code - # 512B is the minimum block size, - # so we allocate all the tensors to this size to make sure - # they tile evenly - tensors = [torch.rand(128, device="cuda") for _ in range(1000)] - while tensors: - del tensors[randint(0, len(tensors) - 1)] + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings( + "garbage_collection_threshold:1.2" + ) - # exercise the history trimming code - torch.rand(128 * 5, device="cuda") + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings("max_split_size_mb:2") - ss = torch.cuda.memory._snapshot() - found_it = False - for seg in ss["segments"]: - self.assertTrue("frames" in seg) - for b in seg["blocks"]: - if b["requested_size"] == 311 * 411 * 4: - self.assertTrue("test_cuda" in b["frames"][0]["filename"]) - found_it = True - self.assertEqual(x.untyped_storage().data_ptr(), b["address"]) - self.assertTrue(found_it) + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none") - if not IS_WINDOWS: - with tempfile.NamedTemporaryFile() as f: - torch.cuda.memory._save_segment_usage(f.name) - with open(f.name) as f2: - self.assertTrue("test_cuda.py" in f2.read()) - del unused - del x - torch.cuda.empty_cache() - ss = torch.cuda.memory._snapshot() - self.assertTrue( - ss["device_traces"][0][-1]["action"] - in ("segment_free", "segment_unmap") + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings( + "pinned_use_cuda_host_register:none" ) - finally: - torch.cuda.memory._record_memory_history(None) + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings( + "pinned_num_register_threads:none" + ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") - def test_direct_traceback(self): - from torch._C._profiler import gather_traceback, symbolize_tracebacks + with self.assertRaises(RuntimeError): + torch.cuda.memory._set_allocator_settings( + "pinned_num_register_threads:1024" + ) - c = gather_traceback(True, True, True) - (r,) = symbolize_tracebacks([c]) - r = str(r) - self.assertTrue("test_cuda.py" in r) - self.assertTrue("unwind" in r) + @parametrize("max_split_size_mb_setting", [False, True]) + def test_raises_oom(self, max_split_size_mb_setting): + if max_split_size_mb_setting: + # CudaCachingAllocator does early return when searching available blocks + # if max_split_size_mb is not set + # Setting this triggers more parts of the code + torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024") + torch.cuda.memory.empty_cache() + with self.assertRaises(torch.cuda.OutOfMemoryError): + torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") + @unittest.skipIf( + not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux" + ) @unittest.skipIf( TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_snapshot_with_cpp(self): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history("state", stacks="all") - x = torch.rand(311, 411, device="cuda") - - ss = torch.cuda.memory._snapshot()["segments"] - found_it = False - for seg in ss: - for b in seg["blocks"]: - if b["requested_size"] == 311 * 411 * 4: - self.assertTrue("::rand" in str(b["frames"])) - found_it = True - self.assertTrue(found_it) + def test_cpp_memory_snapshot_pickle(self): + from torch.utils.cpp_extension import load_inline - finally: - torch.cuda.memory._record_memory_history(None) + source = """ + #include + py::object do_snapshot() { + std::string data = torch::cuda::_memory_snapshot_pickled(); + return py::bytes(data); + } + void record(bool e, bool ctx) { + torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx); + } + """ + m = load_inline( + name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"] + ) + for ctx in (False, True): + try: + m.record(True, ctx) - @skipIfRocm - def test_memory_profiler_viz(self): - with torch.profiler.profile( - with_stack=True, profile_memory=True, record_shapes=True - ) as prof: - x = torch.rand(128, 128, device="cuda") - x * x + x * x - plot = profile_plot(prof) - plot = json.dumps(_profile_to_snapshot(prof)) - self.assertTrue("test_cuda.py" in plot) - self.assertTrue("test_memory_profiler_viz" in plot) - self.assertTrue("category" in plot) + @torch.jit.script + def the_script_fn(): + return torch.rand(311, 411, device="cuda") - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_cycles(self): - fired = False + def run(): + t = the_script_fn() + return pickle.loads(m.do_snapshot()) - def observer(html): - nonlocal fired - fired = True - self.assertTrue("torch.Tensor" in html) - self.assertTrue("test_cuda" in html) - self.assertTrue("cell_contents" in html) + mem = run() + found = False + for s in mem["segments"]: + for b in s["blocks"]: + if b["state"] == "active_allocated": + if b["requested_size"] == 311 * 411 * 4: + if ctx: + frame_text = str(b["frames"]) + # C++ frame + self.assertTrue("::rand" in frame_text) + # script frame + self.assertTrue("the_script_fn" in frame_text) + # python frame + self.assertTrue("case.py" in frame_text) + found = True + last_action = mem["device_traces"][0][-1] + self.assertTrue(last_action["action"] == "alloc") + self.assertTrue(last_action["size"] == 311 * 411 * 4) + self.assertTrue(found) + finally: + m.record(False, False) - disarm = observe_tensor_cycles(observer) + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled") + def test_notifies_oom(self): + x = False - def noop(): - pass + def cb(device, alloc, device_alloc, device_free): + nonlocal x + x = True - try: + torch._C._cuda_attach_out_of_memory_observer(cb) + with self.assertRaises(torch.cuda.OutOfMemoryError): + torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") + self.assertTrue(x) - def create(): - x = torch.empty(3, 4, device="cuda") + def test_allocator_fuzz(self): + # fuzz + state = random.getstate() + random.seed(123) + N = 10000 + try: + mem = [] + total = 0 + c = 0 - def foo(p): - if p: - return foo(not p) - else: - return x + def alloc(): + nonlocal total, c + b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4) + mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda"))) + c += 1 + total += b - return foo + def free(): + nonlocal total + idx = random.randrange(0, len(mem)) + v, x = mem.pop(idx) + assert torch.all(v == x) + total -= x.numel() - create() - gc.collect() - # the callback has to run outside of the collect - # call so it doesn't actual fire until the next - # method call after a gc.collect - noop() - self.assertTrue(fired) + choices = [alloc, free, torch.cuda.memory.empty_cache] + for i in range(N): + while total >= 1024 * 1024 * 1024 / (4 * 10): + free() + (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1]) + action() finally: - disarm() + random.setstate(state) - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots(self): - for context, stacks in ( - ("all", "all" if IS_LINUX else "python"), - ("all", "python"), - (None, "python"), - ): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history( - "all", context=context, stacks=stacks - ) + @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + def test_nvml_get_handler(self): + if not torch.version.hip: + self.assertTrue(torch.cuda._get_pynvml_handler() is not None) + else: + self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) - def run(): - x = torch.rand(128, 128, device="cuda") - x * x + x * x + @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + def test_temperature(self): + self.assertTrue(0 <= torch.cuda.temperature() <= 150) - run() - cpp = stacks == "all" - record_context = context is not None - ss = torch.cuda.memory._snapshot() + @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + def test_power_draw(self): + self.assertTrue(torch.cuda.power_draw() >= 0) - tplot = trace_plot(ss) - splot = segment_plot(ss) - text = json.dumps(ss) + @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + def test_clock_speed(self): + self.assertTrue(torch.cuda.clock_rate() >= 0) - self.assertTrue(record_context == ("test_memory_plots" in text)) - self.assertTrue(cpp == ("::rand" in text)) - self.assertTrue(str(128 * 128 * 4) in text) - finally: - torch.cuda.memory._record_memory_history(None) +MIN_BLOCK_SIZE = 512 +SMALL_SIZE = 1048576 +SMALL_BUFFER = 2097152 +LARGE_BUFFER = 20971520 - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_free_stack(self): - for context in ["alloc", "all", "state"]: - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history(context=context) - x = None - def thealloc(): - nonlocal x - x = torch.rand(3, 4, device="cuda") +def get_cudagraph_segments(pool_id): + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] - def thefree(): - nonlocal x - del x - thealloc() - thefree() - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("thefree" in ss) == (context == "all")) - self.assertTrue(("thealloc" in ss) == (context != "state")) - finally: - torch.cuda.memory._record_memory_history(None) +def get_all_cudagraph_segments(): + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)] - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_history_context(self): - try: - torch.cuda.memory.empty_cache() - x = None - def should_capture1(): - nonlocal x - x = torch.rand(4, 4, device="cuda") +def cudagraphify(fn, inputs, pool=None): + if not TEST_CUDA_GRAPH: + raise unittest.SkipTest("cuda graph test is skipped") - def should_not_capture(): - nonlocal x - x = torch.rand(3, 4, device="cuda") + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + fn(*inputs) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() - def should_capture2(): - nonlocal x - x = torch.rand(4, 4, device="cuda") + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, pool=pool): + static_outputs = fn(*inputs) - # Recording with context and python call stacks should capture the call stack. - torch.cuda.memory._record_memory_history(context="all", stacks="python") - should_capture1() - # Recording with context=None should not capture the call stack. - torch.cuda.memory._record_memory_history(context=None) - should_not_capture() - # Recording with context and python call stacks should capture the call stack. - torch.cuda.memory._record_memory_history(context="all", stacks="python") - should_capture2() + return graph, static_outputs - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue("should_capture1" in ss) - self.assertTrue("should_not_capture" not in ss) - self.assertTrue("should_capture2" in ss) - finally: - torch.cuda.memory._record_memory_history(None) - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "cpp contexts are x86 linux only") - def test_memory_plots_free_segment_stack(self): - for context in ["alloc", "all", "state"]: - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history(context=context) - x = torch.rand(3, 4, device="cuda") - del x - torch.cuda.memory.empty_cache() - - ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("empty_cache" in ss) == (context == "all")) - finally: - torch.cuda.memory._record_memory_history(None) +def int8_cuda(size): + return torch.ones([size], device="cuda", dtype=torch.uint8) - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - def test_memory_snapshot_script(self): - try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history("state", stacks="python") - @torch.jit.script - def foo(): - return torch.rand(311, 411, device="cuda") +def live_blocks(pool_id): + blocks = 0 + seg = get_cudagraph_segments(pool_id) + for segment in get_cudagraph_segments(pool_id): + for block in segment["blocks"]: + blocks += block["state"] == "active_allocated" + return blocks - x = foo() - ss = torch.cuda.memory._snapshot()["segments"] - found_it = False - for seg in ss: - for b in seg["blocks"]: - if b["requested_size"] == 311 * 411 * 4: - self.assertTrue(b["frames"][0]["name"] == "foo") - found_it = True - self.assertTrue(found_it) +def tensor_metadata(x): + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset(), + } - finally: - torch.cuda.memory._record_memory_history(None) - def test_allocator_settings(self): - def power2_div(size, div_factor): - pow2 = 1 - while pow2 < size: - pow2 = pow2 * 2 - if pow2 == size: - return pow2 - step = pow2 / 2 / div_factor - ret = pow2 / 2 - while ret < size: - ret = ret + step - return ret +def reconstruct_from_tensor_metadata(metadata): + s = torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"]) + t.set_( + source=s, + storage_offset=metadata["storage_offset"], + size=metadata["size"], + stride=metadata["stride"], + ) + return t - torch.cuda.memory.empty_cache() - key_allocated = ( - "active_bytes.all.allocated" - if not TEST_CUDAMALLOCASYNC - else "allocated_bytes.all.current" - ) - key_requested = "requested_bytes.all.allocated" - nelems = 21 * 1024 * 1024 - nbytes = 4 * nelems # floats are 4 bytes +@unittest.skipIf(not TEST_CUDA or TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI") +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestBlockStateAbsorption(TestCase): + @property + def expandable_segments(self): + return EXPANDABLE_SEGMENTS - nelems_big = 100 * 1024 * 1024 - nbytes_big = 4 * nelems_big # floats are 4 bytes + def checkCheckpointedBlock(self, before_block, after_block): + for field in ("size", "state"): + self.assertEqual(before_block[field], after_block[field]) - start_mem = torch.cuda.memory_stats()[key_allocated] - torch.cuda.memory._set_allocator_settings("") - x = torch.rand(nelems, device="cuda") + def checkCheckpointedState(self, before_segments, after_segments): + # after may contain additional segments, but all of the segments in before + # should be exactly equivalent to after + after_ptr_to_segment = { + segment["address"]: segment for segment in after_segments + } - # test roundup_power2_divisions single value syntax - reg_mem = torch.cuda.memory_stats()[key_allocated] - start_requested = torch.cuda.memory_stats()[key_requested] - torch.cuda.memory._set_allocator_settings("roundup_power2_divisions:4") - y = torch.rand(nelems, device="cuda") + for before_segment in before_segments: + self.assertTrue(before_segment["address"] in after_ptr_to_segment) + after_segment = after_ptr_to_segment[before_segment["address"]] - pow2_div4_mem = torch.cuda.memory_stats()[key_allocated] - current_requested = torch.cuda.memory_stats()[key_requested] + for field in ( + "device", + "total_size", + "allocated_size", + "active_size", + "segment_type", + "segment_pool_id", + ): + self.assertEqual(before_segment[field], after_segment[field]) - self.assertTrue(reg_mem - start_mem == nbytes) - if not TEST_CUDAMALLOCASYNC: - # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4)) - self.assertTrue(current_requested - start_requested == nbytes) + self.assertEqual( + len(before_segment["blocks"]), len(after_segment["blocks"]) + ) + for before_block, after_block in zip( + before_segment["blocks"], after_segment["blocks"] + ): + self.checkCheckpointedBlock(before_block, after_block) - torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5") - torch.cuda.memory._set_allocator_settings( - "garbage_collection_threshold:0.5,max_split_size_mb:40" + @staticmethod + def setCheckpointPoolState( + device, state, stale_storages_ptr, storages_deleters=None + ): + stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr] + storages_deleters = ( + [] + if not storages_deleters + else [t.untyped_storage()._cdata for t in storages_deleters] + ) + torch._C._cuda_setCheckpointPoolState( + device, state, stale_storages_ptr, storages_deleters ) - # should have reset the power2 divisions now - torch.cuda.memory.empty_cache() - start_mem = torch.cuda.memory_stats()[key_allocated] - z = torch.rand(nelems, device="cuda") - reg_mem = torch.cuda.memory_stats()[key_allocated] - self.assertTrue(reg_mem - start_mem == nbytes) + def checkFunction(self, fn, inputs, pool=None): + graph, outputs = cudagraphify(fn, inputs, pool=pool) - # roundup_power2_divisions knob array syntax - torch.cuda.memory.empty_cache() - torch.cuda.memory._set_allocator_settings( - "garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,128:2,256:2,512:2,1024:1,>:1]" + pool_id = graph.pool() + device = outputs[0].device.index + + segments_before_checkpoint = get_cudagraph_segments(pool_id) + + state = torch._C._cuda_getCheckpointState(device, pool_id) + self.setCheckpointPoolState(device, state, [], []) + + self.checkCheckpointedState( + segments_before_checkpoint, get_cudagraph_segments(pool_id) ) - start_mem = torch.cuda.memory_stats()[key_allocated] - w = torch.rand(nelems, device="cuda") - pow2_div8_mem = torch.cuda.memory_stats()[key_allocated] - if not TEST_CUDAMALLOCASYNC: - # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8)) + def setUp(self): + super().setUp() + self.segment_length = len(get_all_cudagraph_segments()) - torch.cuda.memory.empty_cache() - start_mem = torch.cuda.memory_stats()[key_allocated] - v = torch.rand(nelems_big, device="cuda") + def tearDown(self): + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() - pow2_div2_mem = torch.cuda.memory_stats()[key_allocated] - if not TEST_CUDAMALLOCASYNC: - # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2)) + self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length) - torch.cuda.memory.empty_cache() - torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True") - start_mem = torch.cuda.memory_stats()[key_allocated] - w = torch.rand(nelems, device="cuda") - reg_mem = torch.cuda.memory_stats()[key_allocated] - self.assertTrue(reg_mem - start_mem == nbytes) + super().tearDown() - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings("foo:1,bar:2") + def test_simple(self): + def foo(): + x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8) + x = x + x + x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + y = int8_cuda(SMALL_SIZE) + x1 + z = int8_cuda(SMALL_SIZE) + return x, y, z - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings( - "garbage_collection_threshold:1.2" - ) + self.checkFunction(foo, []) - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings("max_split_size_mb:2") + def test_allocated_in_middle_of_segment(self): + def foo(): + small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] + return small_buffers[5].add_(2) - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none") + self.checkFunction(foo, []) - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings( - "pinned_use_cuda_host_register:none" - ) + def test_multiple_middle_allocations(self): + def foo(): + small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] + return small_buffers[5], small_buffers[8] - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings( - "pinned_num_register_threads:none" - ) + self.checkFunction(foo, []) - with self.assertRaises(RuntimeError): - torch.cuda.memory._set_allocator_settings( - "pinned_num_register_threads:1024" - ) + def test_middle_allocations_contiguous(self): + def foo(): + small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] + return small_buffers[5], small_buffers[6] - @parametrize("max_split_size_mb_setting", [False, True]) - def test_raises_oom(self, max_split_size_mb_setting): - if max_split_size_mb_setting: - # CudaCachingAllocator does early return when searching available blocks - # if max_split_size_mb is not set - # Setting this triggers more parts of the code - torch.cuda.memory._set_allocator_settings("max_split_size_mb:1024") - torch.cuda.memory.empty_cache() - with self.assertRaises(torch.cuda.OutOfMemoryError): - torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") + self.checkFunction(foo, []) - @unittest.skipIf( - not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux" - ) - @unittest.skipIf( - TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" - ) - def test_cpp_memory_snapshot_pickle(self): - from torch.utils.cpp_extension import load_inline + def test_additional_free_following_checkpoint(self): + def foo(): + return (int8_cuda(MIN_BLOCK_SIZE),) - source = """ - #include - py::object do_snapshot() { - std::string data = torch::cuda::_memory_snapshot_pickled(); - return py::bytes(data); - } - void record(bool e, bool ctx) { - torch::cuda::_record_memory_history(e, ctx, 10, ctx, ctx); - } - """ - m = load_inline( - name="snapshot", cpp_sources=[source], functions=["do_snapshot", "record"] - ) - for ctx in (False, True): - try: - m.record(True, ctx) + def foo2(): + return (int8_cuda(MIN_BLOCK_SIZE),) - @torch.jit.script - def the_script_fn(): - return torch.rand(311, 411, device="cuda") + graph, outputs = cudagraphify(foo, []) + pool_id = graph.pool() - def run(): - t = the_script_fn() - return pickle.loads(m.do_snapshot()) + segments_before_checkpoint = get_cudagraph_segments(pool_id) - mem = run() - found = False - for s in mem["segments"]: - for b in s["blocks"]: - if b["state"] == "active_allocated": - if b["requested_size"] == 311 * 411 * 4: - if ctx: - frame_text = str(b["frames"]) - # C++ frame - self.assertTrue("::rand" in frame_text) - # script frame - self.assertTrue("the_script_fn" in frame_text) - # python frame - self.assertTrue("case.py" in frame_text) - found = True - last_action = mem["device_traces"][0][-1] - self.assertTrue(last_action["action"] == "alloc") - self.assertTrue(last_action["size"] == 311 * 411 * 4) - self.assertTrue(found) - finally: - m.record(False, False) + state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled") - def test_notifies_oom(self): - x = False + graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) - def cb(device, alloc, device_alloc, device_free): - nonlocal x - x = True + self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, []) - torch._C._cuda_attach_out_of_memory_observer(cb) - with self.assertRaises(torch.cuda.OutOfMemoryError): - torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") - self.assertTrue(x) + del outputs2 - def test_allocator_fuzz(self): - # fuzz - state = random.getstate() - random.seed(123) - N = 10000 - try: - mem = [] - total = 0 - c = 0 + self.checkCheckpointedState( + segments_before_checkpoint, get_cudagraph_segments(pool_id) + ) - def alloc(): - nonlocal total, c - b = random.randrange(2 * 1024 * 1024 // 4, 20 * 1024 * 1024 // 4) - mem.append((c, torch.full((b,), c, dtype=torch.int32, device="cuda"))) - c += 1 - total += b + # TODO: re-enable + # def test_additional_free_error(self): + # def foo(): + # return int8_cuda(MIN_BLOCK_SIZE), - def free(): - nonlocal total - idx = random.randrange(0, len(mem)) - v, x = mem.pop(idx) - assert torch.all(v == x) - total -= x.numel() + # def foo2(): + # return int8_cuda(MIN_BLOCK_SIZE), - choices = [alloc, free, torch.cuda.memory.empty_cache] - for i in range(N): - while total >= 1024 * 1024 * 1024 / (4 * 10): - free() - (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1]) - action() - finally: - random.setstate(state) + # graph, outputs = cudagraphify(foo, []) + # pool_id = graph.pool() - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") - def test_nvml_get_handler(self): - if not torch.version.hip: - self.assertTrue(torch.cuda._get_pynvml_handler() is not None) - else: - self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) + # segments_before_checkpoint = get_cudagraph_segments(pool_id) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") - def test_temperature(self): - self.assertTrue(0 <= torch.cuda.temperature() <= 150) + # state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") - def test_power_draw(self): - self.assertTrue(torch.cuda.power_draw() >= 0) + # graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) + # with self.assertRaisesRegex(Exception, "being manually freed must be passed"): + # self.setCheckpointPoolState(outputs[0].device.index, state, [], []) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") - def test_clock_speed(self): - self.assertTrue(torch.cuda.clock_rate() >= 0) + def test_tensor_dies_after_checkpoint(self): + def foo(): + return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE) + graph, outputs = cudagraphify(foo, []) + pool_id = graph.pool() + device = outputs[0].device.index -MIN_BLOCK_SIZE = 512 -SMALL_SIZE = 1048576 -SMALL_BUFFER = 2097152 -LARGE_BUFFER = 20971520 + segments_before_checkpoint = get_cudagraph_segments(pool_id) + state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) + output_data_ptrs = [output.data_ptr() for output in outputs] -def get_cudagraph_segments(pool_id): - segments = torch.cuda.memory_snapshot() - return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + del outputs + self.setCheckpointPoolState(device, state, [], []) -def get_all_cudagraph_segments(): - segments = torch.cuda.memory_snapshot() - return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)] + self.assertEqual(live_blocks(pool_id), 2) + torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0]) + self.assertEqual(live_blocks(pool_id), 1) + torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1]) + self.assertEqual(live_blocks(pool_id), 0) + def test_assigning_back_deleter_fns_to_tensor(self): + def foo(x): + return ( + int8_cuda(SMALL_BUFFER) + x, + int8_cuda(SMALL_BUFFER) + x, + int8_cuda(LARGE_BUFFER) + x, + ) -def cudagraphify(fn, inputs, pool=None): - if not TEST_CUDA_GRAPH: - raise unittest.SkipTest("cuda graph test is skipped") + inp = torch.tensor([1], device="cuda") + graph, outputs = cudagraphify(foo, [inp]) + pool_id = graph.pool() + graph.replay() - torch.cuda.synchronize() - stream = torch.cuda.Stream() - stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream): - fn(*inputs) - stream.synchronize() - torch.cuda.current_stream().wait_stream(stream) - torch.cuda.synchronize() + device = outputs[0].device.index - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream, pool=pool): - static_outputs = fn(*inputs) + for i in range(len(outputs)): + self.assertTrue(outputs[i].mean(dtype=torch.float) == 2) - return graph, static_outputs + state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) + output_ptrs = [output.untyped_storage().data_ptr() for output in outputs] + ten_metadata = [tensor_metadata(t) for t in outputs] -def int8_cuda(size): - return torch.ones([size], device="cuda", dtype=torch.uint8) + self.assertEqual(live_blocks(pool_id), 3) + del outputs -def live_blocks(pool_id): - blocks = 0 - seg = get_cudagraph_segments(pool_id) - for segment in get_cudagraph_segments(pool_id): - for block in segment["blocks"]: - blocks += block["state"] == "active_allocated" - return blocks + self.assertEqual(live_blocks(pool_id), 0) + reconstructed_tensors = [ + reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata + ] -def tensor_metadata(x): - return { - "nbytes": x.untyped_storage().nbytes(), - "data_ptr": x.untyped_storage().data_ptr(), - "size": x.shape, - "stride": x.stride(), - "dtype": x.dtype, - "device": x.device, - "storage_offset": x.storage_offset(), - } + for i in range(len(reconstructed_tensors)): + self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2) + inp.add_(1) + graph.replay() -def reconstruct_from_tensor_metadata(metadata): - s = torch._C._construct_storage_from_data_pointer( - metadata["data_ptr"], metadata["device"], metadata["nbytes"] - ) - t = torch.empty([0], device=metadata["device"], dtype=metadata["dtype"]) - t.set_( - source=s, - storage_offset=metadata["storage_offset"], - size=metadata["size"], - stride=metadata["stride"], - ) - return t + for i in range(len(reconstructed_tensors)): + self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3) + self.setCheckpointPoolState( + device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]] + ) -@unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "NYI") -@torch.testing._internal.common_utils.markDynamoStrictTest -class TestBlockStateAbsorption(TestCase): - @property - def expandable_segments(self): - return EXPANDABLE_SEGMENTS + self.assertEqual(live_blocks(pool_id), 3) - def checkCheckpointedBlock(self, before_block, after_block): - for field in ("size", "state"): - self.assertEqual(before_block[field], after_block[field]) + reconstructed_tensors[0] = None + self.assertEqual(live_blocks(pool_id), 2) - def checkCheckpointedState(self, before_segments, after_segments): - # after may contain additional segments, but all of the segments in before - # should be exactly equivalent to after - after_ptr_to_segment = { - segment["address"]: segment for segment in after_segments - } - - for before_segment in before_segments: - self.assertTrue(before_segment["address"] in after_ptr_to_segment) - after_segment = after_ptr_to_segment[before_segment["address"]] - - for field in ( - "device", - "total_size", - "allocated_size", - "active_size", - "segment_type", - "segment_pool_id", - ): - self.assertEqual(before_segment[field], after_segment[field]) - - self.assertEqual( - len(before_segment["blocks"]), len(after_segment["blocks"]) - ) - for before_block, after_block in zip( - before_segment["blocks"], after_segment["blocks"] - ): - self.checkCheckpointedBlock(before_block, after_block) - - @staticmethod - def setCheckpointPoolState( - device, state, stale_storages_ptr, storages_deleters=None - ): - stale_storages_ptr = [t.untyped_storage()._cdata for t in stale_storages_ptr] - storages_deleters = ( - [] - if not storages_deleters - else [t.untyped_storage()._cdata for t in storages_deleters] - ) - torch._C._cuda_setCheckpointPoolState( - device, state, stale_storages_ptr, storages_deleters - ) - - def checkFunction(self, fn, inputs, pool=None): - graph, outputs = cudagraphify(fn, inputs, pool=pool) - - pool_id = graph.pool() - device = outputs[0].device.index - - segments_before_checkpoint = get_cudagraph_segments(pool_id) - - state = torch._C._cuda_getCheckpointState(device, pool_id) - self.setCheckpointPoolState(device, state, [], []) - - self.checkCheckpointedState( - segments_before_checkpoint, get_cudagraph_segments(pool_id) - ) - - def setUp(self): - super().setUp() - self.segment_length = len(get_all_cudagraph_segments()) - - def tearDown(self): - torch.cuda.synchronize() - gc.collect() - torch.cuda.empty_cache() - - self.assertEqual(len(get_all_cudagraph_segments()), self.segment_length) - - super().tearDown() - - def test_simple(self): - def foo(): - x = torch.zeros([SMALL_SIZE * 8], device="cuda", dtype=torch.uint8) - x = x + x - x1 = int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) + int8_cuda(SMALL_SIZE) - y = int8_cuda(SMALL_SIZE) + x1 - z = int8_cuda(SMALL_SIZE) - return x, y, z - - self.checkFunction(foo, []) - - def test_allocated_in_middle_of_segment(self): - def foo(): - small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] - return small_buffers[5].add_(2) - - self.checkFunction(foo, []) - - def test_multiple_middle_allocations(self): - def foo(): - small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] - return small_buffers[5], small_buffers[8] - - self.checkFunction(foo, []) - - def test_middle_allocations_contiguous(self): - def foo(): - small_buffers = [int8_cuda(MIN_BLOCK_SIZE) for _ in range(11)] - return small_buffers[5], small_buffers[6] - - self.checkFunction(foo, []) - - def test_additional_free_following_checkpoint(self): - def foo(): - return (int8_cuda(MIN_BLOCK_SIZE),) - - def foo2(): - return (int8_cuda(MIN_BLOCK_SIZE),) - - graph, outputs = cudagraphify(foo, []) - pool_id = graph.pool() - - segments_before_checkpoint = get_cudagraph_segments(pool_id) - - state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - - graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) - - self.setCheckpointPoolState(outputs[0].device.index, state, outputs2, []) - - del outputs2 - - self.checkCheckpointedState( - segments_before_checkpoint, get_cudagraph_segments(pool_id) - ) - - # TODO: re-enable - # def test_additional_free_error(self): - # def foo(): - # return int8_cuda(MIN_BLOCK_SIZE), - - # def foo2(): - # return int8_cuda(MIN_BLOCK_SIZE), - - # graph, outputs = cudagraphify(foo, []) - # pool_id = graph.pool() - - # segments_before_checkpoint = get_cudagraph_segments(pool_id) - - # state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - - # graph2, outputs2 = cudagraphify(foo2, [], pool=graph.pool()) - # with self.assertRaisesRegex(Exception, "being manually freed must be passed"): - # self.setCheckpointPoolState(outputs[0].device.index, state, [], []) - - def test_tensor_dies_after_checkpoint(self): - def foo(): - return int8_cuda(MIN_BLOCK_SIZE), int8_cuda(MIN_BLOCK_SIZE) - - graph, outputs = cudagraphify(foo, []) - pool_id = graph.pool() - device = outputs[0].device.index - - segments_before_checkpoint = get_cudagraph_segments(pool_id) - state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - - output_data_ptrs = [output.data_ptr() for output in outputs] - - del outputs - - self.setCheckpointPoolState(device, state, [], []) - - self.assertEqual(live_blocks(pool_id), 2) - torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[0]) - self.assertEqual(live_blocks(pool_id), 1) - torch._C._cuda_cudaCachingAllocator_raw_delete(output_data_ptrs[1]) - self.assertEqual(live_blocks(pool_id), 0) - - def test_assigning_back_deleter_fns_to_tensor(self): - def foo(x): - return ( - int8_cuda(SMALL_BUFFER) + x, - int8_cuda(SMALL_BUFFER) + x, - int8_cuda(LARGE_BUFFER) + x, - ) - - inp = torch.tensor([1], device="cuda") - graph, outputs = cudagraphify(foo, [inp]) - pool_id = graph.pool() - graph.replay() - - device = outputs[0].device.index - - for i in range(len(outputs)): - self.assertTrue(outputs[i].mean(dtype=torch.float) == 2) - - state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) - - output_ptrs = [output.untyped_storage().data_ptr() for output in outputs] - ten_metadata = [tensor_metadata(t) for t in outputs] - - self.assertEqual(live_blocks(pool_id), 3) - - del outputs - - self.assertEqual(live_blocks(pool_id), 0) - - reconstructed_tensors = [ - reconstruct_from_tensor_metadata(metadata) for metadata in ten_metadata - ] - - for i in range(len(reconstructed_tensors)): - self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2) - - inp.add_(1) - graph.replay() - - for i in range(len(reconstructed_tensors)): - self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3) - - self.setCheckpointPoolState( - device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]] - ) - - self.assertEqual(live_blocks(pool_id), 3) - - reconstructed_tensors[0] = None - self.assertEqual(live_blocks(pool_id), 2) - - reconstructed_tensors[1] = None - self.assertEqual(live_blocks(pool_id), 1) + reconstructed_tensors[1] = None + self.assertEqual(live_blocks(pool_id), 1) # should not change, we did not pass it in to swap data ptrs reconstructed_tensors[2] = None @@ -4985,6 +4318,7 @@ def test_no_triton_on_import(self): self.assertEqual(rc, "False", "Triton was imported when importing torch!") +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") class TestMemPool(TestCase): def test_mempool_id(self): pool1 = torch.cuda.graph_pool_handle() @@ -5007,10 +4341,25 @@ def test_mempool_with_allocator(self): dummy_allocator_source = """ #include + #include + #include + extern "C" { + C10_EXPORT int called_dummy_alloc = 0; + C10_EXPORT int called_dummy_free = 0; + // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865 - C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { return nullptr; } - C10_EXPORT void dummy_free(void* ptr) { } + C10_EXPORT void* dummy_alloc(size_t size, int device, void* stream) { + called_dummy_alloc = 123; + void* ptr; + C10_CUDA_CHECK(cudaMallocManaged(&ptr, size)); + return ptr; + } + + C10_EXPORT void dummy_free(void* ptr, size_t size, int device, void* stream) { + called_dummy_free = 321; + C10_CUDA_CHECK(cudaFree(ptr)); + } } """ dummy_allocator_libname = "dummy_allocator" @@ -5020,6 +4369,7 @@ def test_mempool_with_allocator(self): is_python_module=False, keep_intermediates=False, verbose=True, + with_cuda=True, ) allocator = torch.cuda.memory.CUDAPluggableAllocator( dummy_allocator, @@ -5031,6 +4381,18 @@ def test_mempool_with_allocator(self): # pool should point to the same allocator as the one passed into it self.assertEqual(allocator.allocator(), pool.allocator) + # no allocations happened yet, so called_dummy_alloc should be 0 + alloc_lib = ctypes.CDLL(dummy_allocator) + called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") + self.assertEqual(called_dummy_alloc.value, 0) + + with torch.cuda.use_mem_pool(pool): + out = torch.randn(1, device="cuda") + + # called_dummy_alloc should be 123 if dummy_alloc was used to allocate + # out tensor + self.assertEqual(called_dummy_alloc.value, 123) + def test_mempool_context(self): active_pool = torch.cuda.MemPoolContext.active_pool() @@ -5082,46 +4444,216 @@ def create_mempool_and_make_active(): self.assertEqual(len(set(active_pool_ids)), 4) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +@torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaOptims(TestCase): # These tests will be instantiate with instantiate_device_type_tests # to apply the new OptimizerInfo structure. - @onlyNativeDeviceTypes + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs" + ) @optims( - [optim for optim in optim_db if "fused" in optim.supported_impls], + [optim for optim in optim_db if optim.has_capturable_arg], dtypes=[torch.float32], ) - def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): - device = device.split(":")[0] - if device not in optim_info.supports_fused_on: - self.skipTest( - f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" - ) - optim_inputs = optim_info.optim_inputs_func(device=device) + def test_graph_optims(self, device, dtype, optim_info): optim_cls = optim_info.optim_cls - for optim_input in optim_inputs: - for _separate_unscale in (True, False): - kwargs = optim_input.kwargs - kwargs["fused"] = True - torch.manual_seed(20) - ( - mod_control, - mod_scaling, - opt_control, - opt_scaling, - data, - loss_fn, - _, - ) = _create_scaling_case( - optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device - ) - optimizer_kwargs = deepcopy(kwargs) - optimizer_kwargs["fused"] = False - if "lr" not in kwargs: - # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr - optimizer_kwargs["lr"] = 1.0 - opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs) - scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",) + ) + + steps_warmup = 3 + steps_train = 2 + + for optim_input in all_optim_inputs: + kwargs = optim_input.kwargs + + # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam + # and torch.optim.adamw + kwargs["lr"] = 0.1 + + for actually_do_graphs in (True, False): + params = [ + torch.randn((i + 5, i + 5), device=device) for i in range(2) + ] + [torch.randn((), device=device)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + + # Control (capturable=False) + kwargs["capturable"] = False + + opt = optim_cls(params_control, **kwargs) + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads[i][j] + opt.step() + + # capturable=True + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads[i][j] + opt.step() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + opt.step() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads[i + steps_warmup][j] + opt.step() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + + @onlyCUDA + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + @optims( + [ + optim + for optim in optim_db + if "fused" in optim.supported_impls and "cuda" in optim.supports_fused_on + ], + dtypes=[torch.float32], + ) + def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info): + optim_cls = optim_info.optim_cls + + steps_warmup = 3 + steps_train = 2 + + optim_inputs = optim_info.optim_inputs_func(device=device) + + for optim_input in optim_inputs: + kwargs = optim_input.kwargs + kwargs["fused"] = True + + for actually_do_graphs in ( + (True, False) if optim_info.has_capturable_arg else (True,) + ): + params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)] + params_control = [p.clone().requires_grad_() for p in params] + params_graphed = [p.clone().requires_grad_() for p in params] + + # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients. + grads = [ + [torch.randn_like(p) for p in params] + for _ in range(steps_warmup + steps_train) + ] + with torch.no_grad(): + grads_control = [[g.clone() for g in gs] for gs in grads] + grads_graphed = [[g.clone() for g in gs] for gs in grads] + + # Gradient Scaler + scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0) + with torch.no_grad(): + scaler_for_control._lazy_init_scale_growth_tracker(device) + + scaler_for_graphed = torch.cuda.amp.GradScaler() + scaler_for_graphed.load_state_dict(scaler_for_control.state_dict()) + with torch.no_grad(): + scaler_for_graphed._lazy_init_scale_growth_tracker(device) + + # Control (capturable=False) + if optim_info.has_capturable_arg: + kwargs["capturable"] = False + opt = optim_cls(params_control, **kwargs) + + for i in range(steps_warmup + steps_train): + for j, p in enumerate(params_control): + p.grad = grads_control[i][j] + scaler_for_control.step(opt) + scaler_for_control.update() + + # capturable=True + if optim_info.has_capturable_arg: + kwargs["capturable"] = True + opt = optim_cls(params_graphed, **kwargs) + + for i in range(steps_warmup): + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + if actually_do_graphs: + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for i in range(steps_train): + if actually_do_graphs: + for j, p in enumerate(params_graphed): + p.grad.copy_(grads_graphed[i + steps_warmup][j]) + g.replay() + else: + # Passing capturable=True to the constructor and running without graphs should still be + # numerically correct, even if it's not ideal for performance. + for j, p in enumerate(params_graphed): + p.grad = grads_graphed[i + steps_warmup][j] + scaler_for_graphed.step(opt) + scaler_for_graphed.update() + + for p_control, p_graphed in zip(params_control, params_graphed): + self.assertEqual(p_control, p_graphed) + + @onlyNativeDeviceTypes + @optims( + [optim for optim in optim_db if "fused" in optim.supported_impls], + dtypes=[torch.float32], + ) + def test_grad_scaling_autocast_fused_optimizers(self, device, dtype, optim_info): + device = device.split(":")[0] + if device not in optim_info.supports_fused_on: + self.skipTest( + f"{device} is not supported for fused on {optim_info.optim_cls.__name__}" + ) + optim_inputs = optim_info.optim_inputs_func(device=device) + optim_cls = optim_info.optim_cls + for optim_input in optim_inputs: + for _separate_unscale in (True, False): + kwargs = optim_input.kwargs + kwargs["fused"] = True + torch.manual_seed(20) + ( + mod_control, + mod_scaling, + opt_control, + opt_scaling, + data, + loss_fn, + _, + ) = _create_scaling_case( + optimizer_ctor=optim_cls, optimizer_kwargs=kwargs, device=device + ) + optimizer_kwargs = deepcopy(kwargs) + optimizer_kwargs["fused"] = False + if "lr" not in kwargs: + # _create_scaling_case will set lr = 1.0 if optimizer_kwargs do not set lr + optimizer_kwargs["lr"] = 1.0 + opt_control = optim_cls(mod_control.parameters(), **optimizer_kwargs) + scaler_scaling = torch.amp.GradScaler(device, init_scale=128.0) scaler_control = torch.amp.GradScaler(device, init_scale=128.0) tracker = TensorTracker() for input, target in data: @@ -5255,6 +4787,7 @@ def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused): self.assertEqual(scaler._growth_tracker, growth_tracker) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") class TestGDS(TestCase): def _get_tmp_dir_fs_type(self): my_path = os.path.realpath("/tmp") @@ -5289,6 +4822,526 @@ def test_gds_read_write_tensors(self): torch.cuda.gds._gds_deregister_buffer(src2.untyped_storage()) +@unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") +class TestCudaAutocast(TestAutocast): + def setUp(self): + super().setUp() + self.autocast_lists = AutocastTestLists(torch.device("cuda:0")) + + def tearDown(self): + del self.autocast_lists + super().tearDown() + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = op_with_args[2] # TEST_WITH_ROCM + if not skip_test: + self._run_autocast_outofplace( + op, args, torch.float16, device="cuda", amp_dtype=torch.float16 + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op_with_args in self.autocast_lists.torch_fp16: + skip_test = False + op, args = op_with_args[0], op_with_args[1] + if len(op_with_args) == 3: + skip_test = op_with_args[2] # TEST_WITH_ROCM + should_error_from_cudnn = "cudnn" in op and ( + "TORCH_CUDNN_V8_API_DISABLED" in os.environ + and int(os.environ["TORCH_CUDNN_V8_API_DISABLED"]) + or torch.cuda.get_device_capability() < (8, 0) + ) + should_error_from_not_implemented = should_error_from_cudnn + if not skip_test: + if should_error_from_not_implemented: + with self.assertRaises( + RuntimeError, + msg=str(op) + " should not be supported for bfloat16!", + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + else: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + else: + with self.assertRaisesRegex( + RuntimeError, "Device does not support bfloat16" + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda" + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_fp32(self): + for op_with_args in self.autocast_lists.torch_fp32: + op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + add_kwargs=maybe_kwargs, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_need_autocast_promote(self): + for op, args in self.autocast_lists.torch_need_autocast_promote: + self._run_autocast_outofplace( + op, args, torch.float32, device="cuda", amp_dtype=torch.float16 + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_torch_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + out_type=out_type, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.nn_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=torch._C._nn, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_bf16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.nn_fp16: + if torch.cuda.is_bf16_supported(): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda", module=torch._C._nn + ) + else: + with self.assertRaisesRegex( + RuntimeError, "Device does not support bfloat16" + ): + self._run_autocast_outofplace( + op, args, torch.bfloat16, device="cuda", module=torch._C._nn + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_nn_fp32(self): + for op, args in self.autocast_lists.nn_fp32: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=torch._C._nn, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_linalg_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.linalg_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=torch._C._linalg, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_fp16(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + for op, args in self.autocast_lists.methods_fp16: + self._run_autocast_outofplace( + op, + args, + torch.float16, + device="cuda", + module=None, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_fp32(self): + for op, args in self.autocast_lists.methods_fp32: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=None, + amp_dtype=torch.float16, + ) + + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_methods_expect_builtin_promote(self): + for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote: + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="cuda", + module=None, + out_type=out_type, + amp_dtype=torch.float16, + ) + + def test_autocast_banned(self): + with torch.autocast("cuda"): + for op, args, module in self.autocast_lists.banned: + with self.assertRaises(RuntimeError): + getattr(module, op)(*args) + + def test_autocast_ignored_types(self): + with torch.autocast("cuda"): + for ignore_type in (torch.double, torch.int32): + a_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") + b_ignore = torch.ones((8, 8), dtype=ignore_type, device="cuda:0") + c_16 = torch.ones((8, 8), dtype=torch.float16, device="cuda:0") + + # Tests if CastPolicy::fp16 ops ignore double and int + # Currently, no ops belonging to this policy support integer inputs. + if ignore_type is torch.double: + with self.assertRaises(RuntimeError): + torch.mm(a_ignore, c_16) + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.mm(a_ignore, b_ignore).dtype + self.assertTrue( + torch.mm(a_ignore, b_ignore).dtype is type_no_autocast + ) + + # Tests if CastPolicy::fp32 ops ignore double and int + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.pow(a_ignore, 2.0).dtype + self.assertTrue(torch.pow(a_ignore, 2.0).dtype is type_no_autocast) + + # Tests if CastPolicy::fp32_set_opt_dtype ops ignore double and int + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.sum(a_ignore).dtype + self.assertTrue(torch.sum(a_ignore).dtype is type_no_autocast) + + # Tests if CastPolicy::fp32_append_dtype ops ignore double and int + # Currently, no ops belonging to this policy support integer inputs. + if ignore_type is torch.double: + with torch.autocast("cuda", enabled=False): + type_no_autocast = torch.norm(a_ignore).dtype + self.assertTrue(torch.norm(a_ignore).dtype is type_no_autocast) + + def test_autocast_custom_enabled(self): + class MyMM(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward(ctx, a, b): + self.assertTrue(a.dtype is torch.float32) + self.assertTrue(b.dtype is torch.float32) + self.assertTrue(torch.is_autocast_enabled()) + ctx.save_for_backward(a, b) + return a.mm(b) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad): + self.assertTrue(torch.is_autocast_enabled()) + a, b = ctx.saved_tensors + a_grad, b_grad = grad.mm(b.t()), a.t().mm(grad) + self.assertTrue(a_grad.dtype is dtype and b_grad.dtype is dtype) + return a_grad, b_grad + + mymm = MyMM.apply + + x = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + y = torch.randn((8, 8), device="cuda", dtype=torch.float32, requires_grad=True) + + dtypes = (torch.float16, torch.bfloat16) if TEST_BF16 else (torch.float16,) + for dtype in dtypes: + with torch.cuda.amp.autocast(dtype=dtype): + output = mymm(x, y) + self.assertTrue(output.dtype is dtype) + loss = output.sum() + loss.backward() + + def test_autocast_custom_cast_inputs(self): + class MyMM(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) + def forward(ctx, a, container, expect_type): + b = container[1][0] + self.assertTrue(a.dtype is expect_type) + self.assertTrue(b.dtype is expect_type) + self.assertFalse(torch.is_autocast_enabled()) + ctx.save_for_backward(a, b) + return a.mm(b) + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, grad): + self.assertFalse(torch.is_autocast_enabled()) + a, b = ctx.saved_tensors + return grad.mm(b.t()), None, None + + mymm = MyMM.apply + + x = torch.randn((8, 8), device="cuda", dtype=torch.float16, requires_grad=True) + # Puts one input tensor in a nested container. y's contained Tensor won't receive a gradient, + # because torch.autograd.Function can't hand gradients back to non-Tensor forward arguments. + # Sets requires_grad=False explicitly so we don't lie about expecting a gradient. + y = ( + 0, + { + 0: torch.randn( + (8, 8), device="cuda", dtype=torch.float16, requires_grad=False + ) + }, + ) + + with torch.autocast("cuda"): + output = mymm(x, y, torch.float32) + self.assertTrue(output.dtype is torch.float32) + loss = output.sum() + loss.backward() + + # Tests if custom_fwd becomes a no-op when mymm runs outside an autocast-enabled region. + output = mymm(x, y, torch.float16) + self.assertTrue(output.dtype is torch.float16) + loss = output.sum() + loss.backward() + + def test_autocast_custom_deprecated_warning(self): + with warnings.catch_warnings(record=True) as w: + + class MyMM(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + self.assertFalse(torch.is_autocast_enabled()) + return x + y + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + _, _ = ctx.saved_tensors + self.assertFalse(torch.is_autocast_enabled()) + return grad, grad + + self.assertRegex( + str(w[0].message), r"`torch.cuda.amp.custom_fwd\(args...\)` is deprecated." + ) + self.assertRegex( + str(w[1].message), r"`torch.cuda.amp.custom_bwd\(args...\)` is deprecated." + ) + + mymm = MyMM.apply + x = torch.randn(3, 3, requires_grad=True) + y = torch.randn(3, 3, requires_grad=True) + with torch.amp.autocast("cuda"): + output = mymm(x, y) + loss = output.sum() + loss.backward() + + def test_autocast_cat_jit(self): + # Reported at https://github.com/pytorch/pytorch/issues/38958 + + class Model(torch.nn.Module): + def forward(self): + a = torch.randn(1) + b = torch.randn(1) + c = torch.cat((a, b), 0) + d = torch.stack([c, c], 0) + return d + + # The JIT here doesn't really matter, we just need to call + # cat via the boxed API + model = Model() + model_jit_script = torch.jit.script(model) + + with torch.autocast("cuda", enabled=True): + model() + model_jit_script() + + # cudnn RNNs require special backend handling (weights are cast to FP16 and reflattened) + # so they get a dedicated test. + # Despite the large number of RNN cases it tries, the test takes < 15 seconds on a Titan V (similar to V100). + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + def test_autocast_rnn(self): + with torch.backends.cudnn.flags(enabled=True, deterministic=True): + # seq, batch, features, hidden size + clses = ("RNN", "GRU", "LSTM") + T, B, F, H = 3, 4, 5, 6 + dtypes = (torch.float16, torch.float32) + input_layouts = ("seq_first", "batch_first", "packed") + + for ( + cls, + num_layers, + bias, + input_layout, + bidirectional, + try_nonpreflattened_weights, + input_dtype, + hidden_dtype, + weight_dtype, + ) in product( + clses, + (1, 2), + (True, False), + input_layouts, + (True, False), + (True, False), + dtypes, + dtypes, + dtypes, + ): + if input_layout == "seq_first": + batch_first = False + x = torch.randn((T, B, F), device="cuda", dtype=input_dtype) + elif input_layout == "batch_first": + batch_first = True + x = torch.randn((B, T, F), device="cuda", dtype=input_dtype) + elif input_layout == "packed": + batch_first = False + x = torch.nn.utils.rnn.pack_padded_sequence( + torch.randn((T, B, F), device="cuda", dtype=input_dtype), + lengths=(3, 2, 1, 3), + enforce_sorted=False, + ) + + rnn = ( + getattr(torch.nn, cls)( + F, + H, + num_layers=num_layers, + bidirectional=bidirectional, + bias=bias, + batch_first=batch_first, + ) + .cuda() + .to(dtype=weight_dtype) + ) + + if try_nonpreflattened_weights: + for p in rnn.parameters(): + with torch.no_grad(): + p.set_(p.clone()) + + h = torch.randn( + (num_layers * (2 if bidirectional else 1), B, H), + device="cuda", + dtype=hidden_dtype, + ) + if cls == "LSTM": + c = torch.randn( + (num_layers * (2 if bidirectional else 1), B, H), + device="cuda", + dtype=hidden_dtype, + ) + h = (h, c) + + with torch.autocast("cuda"): + out, h_out = rnn(x, h) + out = out.data if input_layout == "packed" else out + self.assertEqual(out.dtype, torch.float16) + # Autocast wrapper requires at::_cudnn_rnn is autograd-exposed. This check can't guarantee + # at::_cudnn_rnn is autograd-exposed, but if it fires, it indicates some funny business has + # occurred and we should double check that at::_cudnn_rnn remains autograd-exposed. + self.assertEqual( + out.grad_fn.name(), + "MiopenRnnBackward0" if torch.version.hip else "CudnnRnnBackward0", + ) + out.sum().backward() + grads = [p.grad.clone() for p in rnn.parameters()] + + rnn.zero_grad() + + if cls == "LSTM": + out_control, h_out_control = rnn.to(dtype=torch.float16)( + x.half(), (h[0].half(), h[1].half()) + ) + else: + out_control, h_out_control = rnn.to(dtype=torch.float16)( + x.half(), h.half() + ) + out_control = ( + out_control.data if input_layout == "packed" else out_control + ) + out_control.sum().backward() + grads_control = [p.grad.clone() for p in rnn.parameters()] + + # Compares with default tolerances, even for FP16 execution. Barring nondeterminism, + # autocast and control results should be bitwise identical. + self.assertEqual(out, out_control) + + if cls == "LSTM": + self.assertTrue( + h_out[0].dtype is torch.float16 + and h_out[1].dtype is torch.float16 + ) + self.assertEqual(h_out[0], h_out_control[0]) + self.assertEqual(h_out[1], h_out_control[1]) + else: + self.assertEqual(h_out.dtype, torch.float16) + self.assertEqual(h_out, h_out_control) + for grad, grad_control in zip(grads, grads_control): + self.assertEqual(grad.half(), grad_control) + + def test_autocast_cache_leak(self): + # Reported at https://github.com/pytorch/pytorch/issues/48049 + # Test is used to check, if autocast recaches the same parameters + # when executed in a `torch.no_grad()` block. + + linear = torch.nn.Linear(10, 10).to("cuda") + data = torch.randn(1, 10, device="cuda") + + with torch.autocast("cuda"): + with torch.no_grad(): + out = linear(data) + first_iter_mem = torch.cuda.memory_allocated() + for _ in range(3): + out = linear(data) + self.assertTrue(first_iter_mem == torch.cuda.memory_allocated()) + + def test_autocast_checkpointing(self): + model = torch.nn.Sequential( + torch.nn.Linear(8, 8), torch.nn.Linear(8, 8), torch.nn.Linear(8, 8) + ).cuda() + input = torch.rand( + (8, 8), device="cuda", dtype=torch.float16, requires_grad=True + ) + for reentrant in (True, False): + with torch.autocast("cuda"): + output = checkpoint_sequential(model, 2, input, use_reentrant=reentrant) + self.assertTrue(output.requires_grad) + self.assertTrue(output.dtype is torch.float16) + output.sum().backward() + + def test_cuda_autocast_deprecated_warning(self): + with self.assertWarnsRegex( + FutureWarning, + r"`torch.cuda.amp.autocast\(args...\)` is deprecated. Please use `torch.amp.autocast\('cuda', args...\)` instead.", + ): + with torch.cuda.amp.autocast(): + _ = torch.ones(10) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_device_type_tests(TestCudaOptims, globals()) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index f2a51a03476b4d..816b640eec861e 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -20,7 +20,7 @@ from torch import Tensor from torch._custom_op.impl import CustomOp, infer_schema from torch._library.infer_schema import tuple_to_list -from torch._utils_internal import get_file_path_2 +from torch._utils_internal import get_file_path_2 # @manual from torch.testing._internal import custom_op_db from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_device_type import ( @@ -3428,6 +3428,26 @@ def vmap(info, in_dims, w, x=2, *, y=3, z): self.assertTrue(called) self.assertEqual(result, w * 2 * 3 * 42) + def test_layout_constraint_tags(self): + needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order + flexible_layout = torch._C.Tag.flexible_layout + # (tags, the result of the tag inference) + tests = [ + ({needs_fixed_stride_order}, needs_fixed_stride_order), + ({flexible_layout}, flexible_layout), + # If no tags are provided, then the following is the default + (set(), needs_fixed_stride_order), + # If multiple tags are provided, then we use the most constrained tag. + ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), + ] + from torch._inductor.lowering import get_layout_constraint_tag + + for tags, expected in tests: + with torch.library._scoped_library("mylib", "FRAGMENT") as m: + m.define("foobar(Tensor x) -> Tensor", tags=tags) + result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) + self.assertEqual(result, expected) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap(self): for mode in ["function", "qualname", "opoverload", "c_opdef"]: diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 0cdba26a303cf0..6ece8afac855b9 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -40,6 +40,7 @@ TEST_WITH_ROCM, TEST_WITH_TSAN, TestCase, + xfailIfLinux, ) from torch.utils.data import ( _utils, @@ -1382,6 +1383,8 @@ def test_multiple_dataloaders(self): del loader1_it del loader2_it + # https://github.com/pytorch/pytorch/issues/128551 + @xfailIfLinux def test_segfault(self): p = ErrorTrackingProcess(target=_test_segfault) p.start() diff --git a/test/test_decomp.py b/test/test_decomp.py index b3ccf298516939..7ab3859454ff6b 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -10,7 +10,7 @@ import torch._inductor.decomposition import torch.autograd from torch import Tensor -from torch._decomp import core_aten_decompositions, decomposition_table +from torch._decomp import _is_cia_op, core_aten_decompositions, decomposition_table from torch._dispatch.python import enable_python_dispatcher from torch._ops import DispatchKey from torch.testing import make_tensor @@ -62,7 +62,7 @@ def overload_to_aten_name(op): core_decomposition_names = { overload_to_aten_name(k) for k in core_aten_decompositions() - if isinstance(k, torch._ops.OpOverload) + if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k) } _decomp_test_ops = [ op @@ -1179,24 +1179,6 @@ def test_weight_norm_interface(self, device): def test_sdpa(self, device, dtype, op): # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we # add support for float16 over there we should update this test as well. - - class ScaledDotProductAttention(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward( - self, query_layer, key_layer, value_layer, mask=None, is_causal=True - ): - attn_output = op( - query_layer, - key_layer, - value_layer, - attn_mask=mask, - dropout_p=0.0, - is_causal=is_causal, - ) - return attn_output - query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) @@ -1206,12 +1188,17 @@ def forward( for mask in masks: is_causal = mask is None - attention = ScaledDotProductAttention() decomposed_res = ( torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask ) ) + actual_res = decomposed_res[0] + # Output has form (N, H, L, E), but should be continuous on (L, N, H, E) + # in order for subsequent view(L * N, H * E) to be valid. + # So permute(2, 0, 1, 3) before checking that tensor is contiguous + self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous()) + eager_res = op( query_layer, key_layer, @@ -1221,9 +1208,7 @@ def forward( is_causal=is_causal, ) - self.assertTrue( - torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol) - ) + self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol)) instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 89d8686219eaed..66f6fdeb7237a5 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -33,6 +33,7 @@ StatelessSymbolicContext, statically_known_true, ) +from torch.testing._internal.common_dtype import all_types_and from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -191,7 +192,7 @@ def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None) ) -def create_symtype(cls, pytype, shape_env, val, duck=True): +def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs): from torch._dynamo.source import ConstantSource symbol = shape_env.create_symbol( @@ -199,13 +200,14 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"), dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC, constraint_dim=None, + **kwargs, ) return cls(SymNode(symbol, shape_env, pytype, hint=val)) # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True) -> SymInt: - return create_symtype(SymInt, int, shape_env, i, duck=duck) +def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt: + return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs) def create_symbool(shape_env, b: bool) -> SymBool: @@ -779,6 +781,13 @@ def test_guard_refine_range(self): self.assertTrue(statically_known_true(i0 >= 4)) self.assertTrue(statically_known_true(i0 > 3)) + def test_mul_int_oo_nan(self): + shape_env = ShapeEnv() + s0 = create_symint(shape_env, 5, duck=False) + s1 = create_symint(shape_env, 6, duck=False) + s2 = create_symint(shape_env, 5, duck=False) + bool(s0 * (s1 // s0) == s2) + def test_non_overlapping_and_dense(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) @@ -1060,6 +1069,33 @@ def test_ephemeral_source_unified_with_non_ephemeral_source(self): self.assertEqual(x.stride(), y.stride()) self.assertEqual(x.storage_offset(), y.storage_offset()) + def test_tensor_factory_with_symint(self): + args = list(range(0, 3)) + expected = torch.tensor(args) + + shape_env = ShapeEnv() + sym_args = [create_symint(shape_env, i) for i in args] + + # test tensor factories + for dt in all_types_and(torch.half, torch.bfloat16): + res = torch.tensor(sym_args, dtype=dt) + self.assertEqual(res, expected, exact_dtype=False) + + # test legacy tensor factories + legacy_ctors = [ + torch.Tensor, + torch.LongTensor, + torch.DoubleTensor, + torch.FloatTensor, + torch.IntTensor, + torch.ShortTensor, + torch.HalfTensor, + torch.ByteTensor, + ] + for Tensor in legacy_ctors: + res = Tensor(sym_args) + self.assertEqual(res, expected, exact_dtype=False) + @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index f16dc1c0862744..bf756f7b30fcdc 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -23,6 +23,7 @@ from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor from torch._dynamo.testing import make_test_cls_with_patches, rand_strided from torch._guards import tracing, TracingContext +from torch._higher_order_ops.scan import scan from torch._subclasses.fake_tensor import ( DynamicOutputShapeException, extract_tensor_metadata, @@ -923,6 +924,21 @@ def f(x): self.assertIsInstance(r, FakeTensor) self.assertEqual(r.size(), [3]) + @parametrize("reverse", [False, True]) + def test_scan(self, reverse): + def add(x, y): + return x + y, x + y + + with torch._subclasses.fake_tensor.FakeTensorMode(): + x = torch.randn((3, 5, 7), device="cpu") + init = torch.randn((3, 1, 7), device="cpu") + r = scan(add, init, x, dim=1, reverse=reverse) + + self.assertIsInstance(r[0], FakeTensor) + self.assertIsInstance(r[1], FakeTensor) + self.assertEqual(r[0].size(), init.size()) + self.assertEqual(r[1].size(), x.size()) + instantiate_parametrized_tests(FakeTensorTest) diff --git a/test/test_file_check.py b/test/test_file_check.py new file mode 100644 index 00000000000000..6aea06536781e5 --- /dev/null +++ b/test/test_file_check.py @@ -0,0 +1,53 @@ +# Owner(s): ["module: unknown"] + +from torch.testing import FileCheck +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestFileCheck(TestCase): + def test_not_run(self): + stdout, stderr = self.run_process_no_exception( + """\ +from torch.testing import FileCheck +file_check = FileCheck().check("not run") +del file_check +""", + ) + FileCheck().check("You have not run this instance of FileCheck!").check_next( + "FileCheck checks:" + ).check_next("\tCHECK: not run").run(stdout) + + def test_all_python_api(self): + test_string = """ +check check_same +check_next +check_count +check_dag +check_source_highlighted +~~~~~~~~~~~~~~~~~~~~~~~~ +check_regex +""" + FileCheck().check("check").check_not("check_not").check_same( + "check_same" + ).check_next("check_next").check_count("check_count", 1).check_dag( + "check_dag" + ).check_source_highlighted("check_source_highlighted").check_regex( + r"check_.+" + ).run(test_string) + + FileCheck().run( + """ +# CHECK: check +# CHECK-NOT: check_not +# CHECK-SAME: check_same +# CHECK-NEXT: check_next +# CHECK-DAG: check_dag +# CHECK-SOURCE-HIGHLIGHTED: check_source_highlighted +# CHECK-REGEX: check_.+ + """, + test_string, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index c3d9892ef15023..c91c88abcb4cbe 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -206,7 +206,9 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): _, _, func, ref = self._get_funcs(op) else: func, ref, _, _ = self._get_funcs(op) - for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous): + for sample in op.sample_inputs( + device, dtype, noncontiguous=noncontiguous, allow_higher_dtype_scalars=True + ): ref_kwargs = sample.kwargs # div promotes ints to floats, so we cannot go on the fastpath there div_slowpath = ( @@ -229,11 +231,7 @@ def test_parity(self, device, dtype, op, noncontiguous, inplace): **sample.kwargs, ) except Exception as e: - with ( - self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])) - if not (op.has_no_in_place or not op.supports_out) - else self.assertRaises(type(e)) - ): + with self.assertRaises(type(e)): ref([ref_input, *sample.ref_args], **ref_kwargs) else: expected = ref([ref_input, *sample.ref_args], **ref_kwargs) @@ -309,7 +307,12 @@ def clone(arg): scalar_self_arg_test_complete = False for i, sample in enumerate( - op.sample_inputs(device, dtype, noncontiguous=not is_fastpath) + op.sample_inputs( + device, + dtype, + noncontiguous=not is_fastpath, + allow_higher_dtype_scalars=True, + ) ): (rhs_arg,) = sample.args kwargs = {} or sample.kwargs @@ -351,7 +354,12 @@ def clone(arg): def test_pointwise_op_with_tensor_of_scalarlist_overload( self, device, dtype, op, is_fastpath ): - for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath): + for sample in op.sample_inputs( + device, + dtype, + noncontiguous=not is_fastpath, + allow_higher_dtype_scalars=True, + ): assert isinstance(sample.args, tuple) assert len(sample.args) == 2 inputs = [sample.input, *sample.args] @@ -537,6 +545,24 @@ def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype): # Regression test for https://github.com/pytorch/pytorch/issues/113156 torch._foreach_mul_(tensors, 1) + @onlyCUDA + @dtypes(torch.float32) + def test_foreach_check_stride_ignore_dims_of_one(self, device, dtype): + # default tensor stride is (9, 9, 3, 1). + tensor = torch.ones((2, 1, 3, 3), device=device, dtype=dtype) + strided_tensor = torch.ones( + (2, 1, 3, 3), device=device, dtype=dtype + ).as_strided((2, 1, 3, 3), (9, 1, 3, 1)) + left_inputs = [tensor, strided_tensor] + right_inputs = [strided_tensor, tensor] + compare_result = tensor + strided_tensor + foreach_add_check_ = ForeachFuncWrapper(torch._foreach_add) + out = foreach_add_check_( + (left_inputs, right_inputs), is_cuda=True, expect_fastpath=True + ) + for res in out: + self.assertEqual(res, compare_result) + @ops( filter(lambda op: op.supports_out, foreach_binary_op_db), dtypes=OpDTypes.supported, @@ -820,7 +846,14 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): method, ref, inplace_method, ref_inplace = self._get_funcs(op) # tensors: ['cuda', 'cpu] tensors = next( - iter(op.sample_inputs(device, dtype, num_input_tensors=[2])) + iter( + op.sample_inputs( + device, + dtype, + num_input_tensors=[2], + allow_higher_dtype_scalars=True, + ) + ) ).input tensors[1] = tensors[1].to("cpu") if not op.supports_out: @@ -848,10 +881,26 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): @ops(filter(lambda op: op.supports_out, foreach_binary_op_db)) def test_binary_op_tensors_on_different_devices(self, device, dtype, op): _cuda_tensors = next( - iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True)) + iter( + op.sample_inputs( + device, + dtype, + num_input_tensors=[2], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input _cpu_tensors = next( - iter(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True)) + iter( + op.sample_inputs( + "cpu", + dtype, + num_input_tensors=[2], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors)) @@ -881,10 +930,24 @@ def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op): # tensors3: ['cuda', 'cpu] # first tensorlist is zero-size when float32 _cuda_tensors = list( - op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True) + op.sample_inputs( + device, + dtype, + num_input_tensors=[3], + same_size=True, + allow_higher_dtype_scalars=True, + ) )[int(dtype == torch.float32)].input _cpu_tensors = next( - iter(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True)) + iter( + op.sample_inputs( + "cpu", + dtype, + num_input_tensors=[3], + same_size=True, + allow_higher_dtype_scalars=True, + ) + ) ).input tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors)) @@ -1224,7 +1287,9 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): foreach_copy_ = op.inplace_variant copy_ = op.ref_inplace for non_blocking in (False, True): - for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for sample in op.sample_inputs( + device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True + ): with torch.no_grad(): ref_input = [t.clone().detach() for t in sample.input] foreach_copy_(sample.input, sample.args[0], non_blocking) @@ -1244,7 +1309,9 @@ def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op): def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): # check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_ foreach_copy_ = ForeachFuncWrapper(op.inplace_variant) - for sample in op.sample_inputs(device, dtype, noncontiguous=False): + for sample in op.sample_inputs( + device, dtype, noncontiguous=False, allow_higher_dtype_scalars=True + ): for src_dtype in floating_types_and(torch.half, torch.bfloat16): if src_dtype == dtype: continue @@ -1253,13 +1320,12 @@ def test_foreach_copy_with_multi_dtypes(self, device, dtype, op): out = foreach_copy_( (self_tensors, src_tensors), is_cuda=True, expect_fastpath=True ) - self.assertEqual( - out, - [ - torch.empty_like(t).copy_(s) - for t, s in zip(self_tensors, src_tensors) - ], - ) + ref_out = [ + torch.empty_like(t).copy_(s) + for t, s in zip(self_tensors, src_tensors) + ] + for t, ref_t in zip(out, ref_out): + self.assertTrue(torch.equal(t, ref_t)) # Test reverse-mode & forward-mode AD if supported. @onlyCUDA @@ -1311,6 +1377,7 @@ def test_autodiff(self, device, dtype, op, inplace): dtype, requires_grad=True, num_input_tensors=[5], + allow_higher_dtype_scalars=True, **value_range, ): # Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])` @@ -1424,10 +1491,10 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): return False, "In-place abs is not supported for complex tensors." if op.name == "_foreach_sub" and ( ( - isinstance(sample.args[0], list) - and any(isinstance(a, bool) for a in sample.args[0]) + isinstance(sample.args[-1], list) + and any(isinstance(a, bool) for a in sample.args[-1]) ) - or isinstance(sample.args[0], bool) + or isinstance(sample.args[-1], bool) ): return False, _BOOL_SUB_ERR_MSG if op.name == "_foreach_norm" and (not is_inplace): @@ -1438,10 +1505,10 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): ) rhs_arg_has_complex_number = sample.args and ( ( - isinstance(sample.args[0], list) - and any(isinstance(a, complex) for a in sample.args[0]) + isinstance(sample.args[-1], list) + and any(isinstance(a, complex) for a in sample.args[-1]) ) - or (isinstance(sample.args[0], complex)) + or (isinstance(sample.args[-1], complex)) ) if rhs_arg_has_complex_number and dtype == torch.float64: if op.name in ( @@ -1451,6 +1518,8 @@ def check_autodiff_sample(op, sample, dtype, is_inplace): "_foreach_minimum", ): return False, "clamp is not supported for complex types" + if op.name == "_foreach_lerp" and is_inplace: + return False, "value cannot be converted to type double without overflow" if not is_inplace: return False, "" else: diff --git a/test/test_jit.py b/test/test_jit.py index 5b3a40973acac3..c3af8bc9f48fcf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -15636,7 +15636,7 @@ def fn(string): def test_unicode_comments(self): @torch.jit.script def test(self, a): - # shrug + # 🤷🤷🤷🤷 return torch.nn.functional.relu(a) def test_get_set_state_with_tensors(self): diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index bf42cf086376ec..580f1301d40605 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -66,6 +66,18 @@ def _compare_mts(mt1, mt2, rtol=1e-05, atol=1e-08): if not _tensors_match(a, b, exact=False, rtol=rtol, atol=atol): raise ValueError("The data in MaskedTensor mt1 and MaskedTensor mt2 do not match") +def _compare_forward_backward(data, mask, fn): + mt = masked_tensor(data, mask, requires_grad=True) + masked_res = fn(mt) + masked_res.sum().backward() + + t = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() + tensor_res = fn(t) + tensor_res.sum().backward() + + _compare_mt_t(masked_res, tensor_res) + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) @@ -164,15 +176,8 @@ def test_softmax(self, device): ], device=device ) - mt = masked_tensor(data, mask, requires_grad=True) - masked_res = torch.softmax(mt, -1) - masked_res.sum().backward() - xinf = data.masked_fill(~mask, float("-inf")).detach().clone().requires_grad_() - tensor_res = torch.softmax(xinf, -1) - tensor_res.sum().backward() - _compare_mt_t(masked_res, tensor_res) - _compare_mt_t(mt.grad, xinf.grad, atol=1e-06) + _compare_forward_backward(data, mask, lambda t: torch.softmax(t, -1)) def test_where(self, device): data = torch.tensor([-10.0, -5, 0, 5, 10, 50, 60, 70, 80, 90, 100], device=device) @@ -192,6 +197,35 @@ def test_where(self, device): _compare_mt_t(mx.grad, x.grad) _compare_mt_t(my.grad, y.grad) + def test_unfold(self, device): + data = torch.rand(5, 5, device=device) + mask = torch.rand(5, 5, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: t.unfold(1, 2, 2)) + + def test_nn_unfold(self, device): + data = torch.rand(2, 5, 3, 4, device=device) + mask = torch.rand(2, 5, 3, 4, device=device) > 0.5 + _compare_forward_backward(data, mask, lambda t: torch.nn.functional.unfold(t, kernel_size=(2, 3))) + + def test_stack(self, device): + masked_tensors = [ + masked_tensor( + torch.rand(2, 5, 3, 4, device=device), + torch.rand(2, 5, 3, 4, device=device) > 0.5, + requires_grad=True, + ) for _ in range(3) + ] + + data_tensors = [mt.get_data().detach().clone().requires_grad_() for mt in masked_tensors] + masked_res = torch.stack(masked_tensors) + tensor_res = torch.stack(data_tensors) + + masked_res.sum().backward() + tensor_res.sum().backward() + _compare_mt_t(masked_res, tensor_res) + for mt, t in zip(masked_tensors, data_tensors): + _compare_mt_t(mt.grad, t.grad, atol=1e-06) + def test_to_sparse(self, device): for sample in _generate_sample_data(device=device): data = sample.input diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 70235665cb6fef..dc82e517ca30bd 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -560,6 +560,7 @@ def test_float8_scale_fast_accum(self, device) -> None: self.assertEqual(out_fp8, out_fp8_s) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific") @skipIfRocm() @parametrize("use_fast_accum", [True, False]) def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index b61e50fb4f6227..5f192d7c349dfe 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -1469,24 +1469,30 @@ def _lstm_params_list(self): params_list = list(params_dict.values()) return params_list - def _cast_dtype(self, input, bf16): - if bf16: + def _cast_dtype(self, input, dtype): + if dtype == torch.bfloat16: input = input.to(torch.bfloat16) + elif dtype == torch.half: + input = input.to(torch.half) return input - @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path") def test_lstm(self): seed = 2023 torch.manual_seed(seed) params_list = self._lstm_params_list() for dtype in types: - bf16 = True if dtype == torch.bfloat16 and torch.ops.mkldnn._is_mkldnn_bf16_supported() else False + bf16 = dtype == torch.bfloat16 + fp16 = dtype == torch.half rtol = 1.3e-6 atol = 1e-5 + if bf16: rtol = 0.02 atol = 0.02 + if fp16: + rtol = 1e-3 + atol = 1e-3 for input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, seq_len, training \ in itertools.product(*params_list): num_directions = 2 if bidirectional else 1 @@ -1496,7 +1502,9 @@ def test_lstm(self): input = torch.randn(seq_len, batch_size, input_size, dtype=torch.float32) h = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) c = torch.randn(num_layers * num_directions, batch_size, hidden_size, dtype=torch.float32) - + if fp16: + # TODO add traing support when oneDNN support lstm FP16 training + training = False model = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional, bias=bias, dropout=dropout, batch_first=batch_first).float() model.train() if training else model.eval() @@ -1510,15 +1518,25 @@ def test_lstm(self): model1 = copy.deepcopy(model) model2 = copy.deepcopy(model) - with torch.cpu.amp.autocast(enabled=bf16, dtype=torch.bfloat16), torch.no_grad() if not training else nullcontext(): + with torch.no_grad() if not training else nullcontext(): with torch.backends.mkldnn.flags(enabled=False): torch.manual_seed(seed) - output1, (hn1, cn1) = self._cast_dtype(model1, bf16)(self._cast_dtype(input1, bf16), - (self._cast_dtype(h1, bf16), - self._cast_dtype(c1, bf16))) + output1, (hn1, cn1) = self._cast_dtype(model1, dtype)( + self._cast_dtype(input1, dtype), + ( + self._cast_dtype(h1, dtype), + self._cast_dtype(c1, dtype), + ), + ) torch.manual_seed(seed) - output2, (hn2, cn2) = model2(input2, (h2, c2)) + output2, (hn2, cn2) = self._cast_dtype(model2, dtype)( + self._cast_dtype(input2, dtype), + ( + self._cast_dtype(h2, dtype), + self._cast_dtype(c2, dtype), + ), + ) self.assertEqual(output1, output2, rtol=rtol, atol=atol) self.assertEqual(hn1, hn2, rtol=rtol, atol=atol) self.assertEqual(cn1, cn2, rtol=rtol, atol=atol) @@ -1533,8 +1551,13 @@ def test_lstm(self): self.assertEqual(input1.grad, input2.grad, rtol=rtol, atol=atol) for name, para in model1.named_parameters(): - self.assertEqual(para, self._cast_dtype(getattr(model2, name), bf16)) - self.assertEqual(para.grad, self._cast_dtype(getattr(model2, name).grad, bf16), rtol=rtol, atol=atol) + self.assertEqual(para, getattr(model2, name)) + self.assertEqual( + para.grad, + getattr(model2, name).grad, + rtol=rtol, + atol=atol, + ) with torch.backends.mkldnn.flags(enabled=False): torch.manual_seed(seed) diff --git a/test/test_model_exports_to_core_aten.py b/test/test_model_exports_to_core_aten.py index bbf0a8ba3c04ac..aae14c28b8d6e0 100644 --- a/test/test_model_exports_to_core_aten.py +++ b/test/test_model_exports_to_core_aten.py @@ -27,7 +27,7 @@ def test_vit_aten_export(self): m = m.eval() input_shape = (1, 3, 224, 224) example_inputs = (torch.randn(input_shape),) - m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs)) + m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() m(*example_inputs) m = export.export(m, copy.deepcopy(example_inputs)) ops = _get_ops_list(m.graph_module) diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py index abbaaed4491a74..457b9648d73c95 100644 --- a/test/test_module_tracker.py +++ b/test/test_module_tracker.py @@ -3,7 +3,9 @@ from copy import copy import torch +from torch import nn from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo +from torch.utils.checkpoint import checkpoint from torch.utils.module_tracker import ModuleTracker @@ -14,7 +16,7 @@ def test_module_hierarchy(self): seen_fw = [] seen_bw = [] - class Foo(torch.nn.Module): + class Foo(nn.Module): def forward(self, x): x = x["a"].relu_() seen_fw.append((copy(tracker.parents), tracker.is_bw)) @@ -23,12 +25,12 @@ def forward(self, x): ) return {"a": torch.mm(x, x)} - class Mod(torch.nn.Module): + class Mod(nn.Module): def __init__(self) -> None: super().__init__() self.a = Foo() - self.b = torch.nn.ModuleDict({"nest": Foo()}) - self.c = torch.nn.ModuleList([Foo()]) + self.b = nn.ModuleDict({"nest": Foo()}) + self.c = nn.ModuleList([Foo()]) def forward(self, x): x = self.c[0](x) @@ -68,8 +70,36 @@ def forward(self, x): ], ) + def test_confused_hierarchy(self): + class MyMod(nn.Module): + def __init__(self): + super().__init__() + self.inner = nn.Linear(2, 2) + self.ran = False + + def forward(self, inp): + if not self.ran: + self.ran = True + return self(inp) + else: + self.ran = False + return self.inner(inp) + + mod = MyMod() + inp = torch.rand(1, 2, requires_grad=True) + + # Should not fail + with ModuleTracker() as tracker: + res = mod(inp) + res.sum().backward() + + # Should not fail + with ModuleTracker() as tracker: + res = checkpoint(lambda inp: mod(inp), inp) + res.sum().backward() + def test_bw_detection(self): - mod = torch.nn.Linear(2, 2) + mod = nn.Linear(2, 2) with ModuleTracker() as tracker: mod(torch.rand(2, requires_grad=True)).sum().backward() diff --git a/test/test_mps.py b/test/test_mps.py index f4aa05177a7929..a342fa3425853e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -344,6 +344,7 @@ def mps_ops_modifier(ops): 'tanh', 'tensor_split', 'transpose', + 'transpose_copy', 'T', 'unbind', 'unflatten', @@ -438,6 +439,7 @@ def mps_ops_modifier(ops): 'logical_not', 'logical_or', 'logical_xor', + 'logsumexp', 'long', 'masked_fill', 'masked.mean', @@ -445,6 +447,7 @@ def mps_ops_modifier(ops): 'masked.std', 'masked.sum', 'masked.var', + 'masked.logsumexp', 'matmul', 'mean', 'mm', @@ -832,9 +835,6 @@ def mps_ops_modifier(ops): # Unsupported dtypes # bmm is not supported for integral types 'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - # Cannot convert a MPS Tensor to float64 dtype. The tensors - # input data is created with double in common_methods_invocations.py - 'nn.functional.batch_norm': [torch.float32], 'ones_like': None, 'zeros_like': None, @@ -1204,6 +1204,29 @@ def __exit__(self, exec_type, exec_value, traceback): raise RuntimeError(msg) +class TestAutocastMPS(TestCase): + + def test_matmul_autocast(self): + autocast_tensor_A = torch.rand((8, 8), device="mps") + autocast_tensor_B = torch.rand((8, 8), device="mps") + tensor_A = autocast_tensor_A.clone().detach() + tensor_B = autocast_tensor_B.clone().detach() + autocast_output_tensor = torch.empty(8, 8) + output_tensor = autocast_output_tensor.clone().detach() + + with torch.autocast(device_type="mps"): + autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B) + autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor) + + output_tensor = torch.mm(tensor_A, tensor_B) + output_tensor = torch.mm(tensor_A, output_tensor) + + self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16") + self.assertEqual(autocast_output_tensor, + output_tensor.to(torch.float16), + f"Autocast & non-autocast tensors did not match, \ + got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}") + # Expand TestCase class with Memory Leak Detection on MPS device class TestCaseMPS(TestCase): _do_mps_memory_leak_check = True @@ -1917,6 +1940,20 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + @xfailIf(product_version < 15.0) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_large_bmm(self, dtype): + batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps') + batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps') + output_cpu = torch.bmm(batch1.cpu(), batch2.cpu()) + output_mps = torch.bmm(batch1, batch2) + + # Using the low precision comparison for FP16 + tol = 1e-2 if dtype == torch.float16 else None + self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol) + self.assertEqual(output_cpu.size(), output_mps.size()) + + def test_addr(self): A = torch.ones(5, 10).to("mps") B = torch.ones(5).to("mps") @@ -2696,6 +2733,12 @@ def test_ifft(self): # Expecting the inverted to yield the original signal self.assertEqual(ifft_result, signal) + # Regression test for https://github.com/pytorch/pytorch/issues/135223 + def test_fftfreq(self): + freq_cpu = torch.fft.fftfreq(10**4, device='cpu') + freq_mps = torch.fft.fftfreq(10**4, device='mps') + self.assertEqual(freq_cpu, freq_mps) + def test_instance_norm(self): def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False): @@ -5685,6 +5728,10 @@ def helper(n, c, h, w, dtype=torch.float32): helper(2, 8, 4, 5, dtype=torch.int32) helper(2, 8, 4, 5, dtype=torch.int64) helper(2, 8, 4, 5, dtype=torch.bool) + # Regression test for https://github.com/pytorch/pytorch/issues/136132 + x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2) + self.assertEqual(x.numel(), 8) + self.assertEqual(x.max().item(), 30.0) # Test forward prod def test_prod(self): @@ -6529,6 +6576,18 @@ def helper(shape): helper((2, 8, 4, 5)) + def test_logsumexp(self): + def helper(shape): + cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) + x = cpu_x.detach().clone().to('mps') + + log_result = torch.logsumexp(x, -1) + log_result_cpu = torch.logsumexp(cpu_x, -1) + + self.assertEqual(log_result, log_result_cpu) + + helper((2, 8, 4, 5)) + # Test concat forward def test_cat2(self): @@ -9363,18 +9422,37 @@ def _compare_tensors(self, y, ref): err = ((y - ref).abs() / denom).mean().item() self.assertLess(err, 0.01) - def _test_sdpa_no_mask(self, is_causal: bool, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128): + def _test_sdpa_no_mask( + self, + is_causal: bool, + dtype: torch.dtype, + L: int = 1, + S: int = 72, + NH: int = 32, + HS: int = 128, + requires_grad: bool = False + ): + torch.manual_seed(1729) with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): - q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps") + q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad) k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") + q_cpu = q.cpu().detach().cpu().requires_grad_(requires_grad) + k_cpu = k.cpu() + v_cpu = v.cpu() y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal) - y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), dropout_p=0.0, is_causal=is_causal) + y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal) self._compare_tensors(y.cpu(), y_ref) + if requires_grad and torch.is_grad_enabled(): + y.sum().backward() + y_ref.sum().backward() + + self._compare_tensors(q.grad.cpu(), q_cpu.grad) + def test_sdpa_no_mask_no_causal_fp32(self): self._test_sdpa_no_mask(False, torch.float32) @@ -9396,6 +9474,12 @@ def test_sdpa_no_mask_causal_fp16_L7_S17(self): def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self): self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121) + def test_sdpa_no_mask_no_causal_fp32_grad(self): + self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) + + with torch.no_grad(): + self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) + def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128): torch.manual_seed(1729) causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps')) @@ -9424,7 +9508,7 @@ def test_sdpa_mask_fp16_L6(self): self._test_sdpa_mask(torch.float16, 6) def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): - self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121) + self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) class TestGatherScatter(TestCaseMPS): @@ -12136,6 +12220,13 @@ def req_grad(t): cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True) mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True) + if ( + op.name == "nn.functional.pad" + and op.variant_test_name in ["replicate", "reflect"] + and dtype == torch.float16 + ): + atol = 1e-5 + rtol = 1.5e-3 self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol) diff --git a/test/test_multiprocessing_spawn.py b/test/test_multiprocessing_spawn.py index df29e07b2227b2..acad97827ec423 100644 --- a/test/test_multiprocessing_spawn.py +++ b/test/test_multiprocessing_spawn.py @@ -92,7 +92,7 @@ def _test_nested(i, pids_queue, nested_child_sleep, start_method): # Kill self. This should take down the child processes as well. os.kill(os.getpid(), signal.SIGTERM) -class _TestMultiProcessing(TestCase): +class _TestMultiProcessing: start_method = None def test_success(self): @@ -194,11 +194,10 @@ def _test_nested(self): self.assertLess(time.time() - start, nested_child_sleep / 2) time.sleep(0.1) - @unittest.skipIf( NO_MULTIPROCESSING_SPAWN, "Disabled for environments that don't support the spawn start method") -class SpawnTest(_TestMultiProcessing): +class SpawnTest(TestCase, _TestMultiProcessing): start_method = 'spawn' def test_exception_raises(self): @@ -222,7 +221,7 @@ def _test_process_exited(self): IS_WINDOWS, "Fork is only available on Unix", ) -class ForkTest(_TestMultiProcessing): +class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' @@ -230,11 +229,7 @@ class ForkTest(_TestMultiProcessing): IS_WINDOWS, "Fork is only available on Unix", ) -class ForkServerTest(_TestMultiProcessing): - start_method = 'forkserver' - - -class _ParallelTest: +class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing): orig_paralell_env_val = None def setUp(self): @@ -250,38 +245,17 @@ def tearDown(self): os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val -@unittest.skipIf( - NO_MULTIPROCESSING_SPAWN, - "Disabled for environments that don't support the spawn start method") -class ParallelSpawnShouldFallbackAndWorkTest(SpawnTest, _ParallelTest): - pass - - -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ParallelForkShouldFallbackAndWorkTest(ForkTest, _ParallelTest): - pass - - -@unittest.skipIf( - IS_WINDOWS, - "Fork is only available on Unix", -) -class ParallelForkServerShouldWorkTest(ForkServerTest, _ParallelTest): - pass - - @unittest.skipIf( IS_WINDOWS, "Fork is only available on Unix", ) class ParallelForkServerPerfTest(TestCase): + def test_forkserver_perf(self): + start_method = 'forkserver' expensive = Expensive() - nprocs = 6 + nprocs = 4 orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) # test the non parallel case @@ -289,7 +263,7 @@ def test_forkserver_perf(self): start = time.perf_counter() mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) elapsed = time.perf_counter() - start - # the time should be at least 6x the sleep time + # the elapsed time should be at least {nprocs}x the sleep time self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs) # test the parallel case @@ -297,9 +271,8 @@ def test_forkserver_perf(self): start = time.perf_counter() mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) elapsed = time.perf_counter() - start - - # the time should be at most 1x the sleep time + small overhead - self.assertLess(elapsed, Expensive.SLEEP_SECS + 10) + # the elapsed time should be less than {nprocs}x the sleep time + self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs) if orig_paralell_env_val is None: del os.environ[mp.ENV_VAR_PARALLEL_START] @@ -308,7 +281,7 @@ def test_forkserver_perf(self): class Expensive: - SLEEP_SECS = 10 + SLEEP_SECS = 5 # Simulate startup overhead such as large imports time.sleep(SLEEP_SECS) diff --git a/test/test_native_mha.py b/test/test_native_mha.py index 9a07485cb2e946..307115147852ff 100644 --- a/test/test_native_mha.py +++ b/test/test_native_mha.py @@ -276,8 +276,11 @@ def do_pad_all(tensors): @torch.no_grad() def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding, pad_all, fused): - if TEST_WITH_ROCM and use_nt: - self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if TEST_WITH_ROCM: + if use_nt: + self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if use_padding and not pad_all and fused: + self.skipTest("Large numerical errors on ROCM to investigate.") for need_weights in (False, not pad_all): with self.subTest(use_padding=use_padding, pad_all=pad_all, use_nt=use_nt, need_weights=need_weights, diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index bd4010d81083a1..c1b25313dbc24c 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -1,10 +1,12 @@ # Owner(s): ["module: nestedtensor"] +import ast import io import itertools import math import sys import unittest +from contextlib import nullcontext from functools import partial from typing import Optional, Tuple @@ -1025,9 +1027,14 @@ def test_embedding(self, device, layout): ) emb = torch.nn.Embedding(100, 8, device=device) y = emb(x) - ys = y.unbind() - for i, inp in enumerate(inputs): - self.assertEqual(emb(inp), ys[i]) + + @torch._dynamo.disable + def check(inputs, y): + ys = y.unbind() + for i, inp in enumerate(inputs): + self.assertEqual(emb(inp), ys[i]) + + check(inputs, y) @skipMeta @torch.inference_mode() @@ -4331,12 +4338,14 @@ def test_op_dim_reduce_ragged_idx_1_different_output_shape( out_expected = torch.cat( [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()] ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) self.assertFalse( out_actual.is_nested, f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", ) # output is a dense tensor - self.assertTrue(torch.allclose(out_actual, out_expected)) + self.assertEqual(out_actual, out_expected) @dtypes(torch.float32) @parametrize("requires_grad", [False, True]) @@ -4596,12 +4605,14 @@ def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape( for t in nt_transposed.unbind() ] ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) self.assertFalse( out_actual.is_nested, f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", ) # output is a dense tensor - self.assertTrue(torch.allclose(out_actual, out_expected, rtol=1e-4)) + self.assertEqual(out_actual, out_expected) @dtypes(torch.float32) @parametrize( @@ -5064,11 +5075,12 @@ def test_mean_dim_reduce_multiple_dims( ): """ Mean on NestedTensor fails when trying to reduce across multiple dimensions + only if the batch or ragged dims are included """ tensor_lists = self._get_example_tensor_lists( include_list_of_lists=False, include_requires_grad=components_require_grad ) - reduce_dims = ((0, 1), (2, 3), (2, 3, 4)) + reduce_dims = ((0, 1), (2, 3), (2, 3, 4), (0, 3), (1, 2)) for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): nt = torch.nested.nested_tensor( @@ -5080,10 +5092,20 @@ def test_mean_dim_reduce_multiple_dims( ) if nt.dim() > reduce_dim[-1]: - with self.assertRaisesRegex( - RuntimeError, - "not supported across multiple dimensions for NestedTensor", - ): + ragged_or_batch_included = ( + nt._ragged_idx in reduce_dim or 0 in reduce_dim + ) + + context = ( + self.assertRaisesRegex( + RuntimeError, + "not supported across multiple dimensions for NestedTensor", + ) + if ragged_or_batch_included + else nullcontext() + ) + + with context: out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim) @dtypes(torch.float32) @@ -5950,6 +5972,37 @@ def check_nt_equality(x, y): self.assertFalse(clone.is_contiguous()) check_nt_equality(detached, transposed) + def test_permute(self, device): + nt = random_nt_from_dims( + [2, None, 3, 5], device, torch.float32, layout=torch.jagged + ) + nt_shape = nt.shape + nt_inner_shape = nt.values().shape + with self.assertRaisesRegex( + ValueError, + r"permute\(\): number of dimensions in the tensor input \(4\) " + + r"does not match the length of the desired ordering of dimensions \(3\).", + ): + nt.permute(0, 2, 1) + with self.assertRaisesRegex( + ValueError, r"permute\(\): duplicate dims are not allowed." + ): + nt.permute(0, 2, -2, 3) + with self.assertRaisesRegex( + ValueError, "Permute is not supported on the batch dimension for jagged NT" + ): + nt.permute(1, 0, 2, 3) + nt_permute = nt.permute(0, 2, 1, -1) + self.assertEqual( + nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3]) + ) + self.assertEqual( + nt_permute.values().shape, + (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]), + ) + self.assertEqual(nt_permute._ragged_idx, 2) + self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt) + def test_to_dtype(self, device): nt = random_nt_from_dims( [2, None, 3], device, torch.float32, layout=torch.jagged @@ -6680,6 +6733,38 @@ def get_values(): self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad) self.assertEqual(v16_dense_eager.grad, v16_nt_compile.grad) + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Platform doesn't support flash or mem-efficient attention", + ) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @onlyCUDA + @skipIfTorchDynamo() + def test_sdpa_flop_counter(self, device): + from torch.utils.flop_counter import FlopCounterMode + + def get_flops(nt): + flop_counter = FlopCounterMode(display=False) + with flop_counter: + ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt) + ret.values().sum().backward() + return flop_counter.get_total_flops() + + values = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16 + ) + offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32) + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16) + + values_meta = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16 + ) + offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32) + nt_meta = convert_jagged_to_nested_tensor(values, offsets, max_length=16) + + self.assertEqual(get_flops(nt), get_flops(nt_meta)) + @skipIfTorchDynamo() def test_nested_tensor_activation_checkpoint(self, device): values = torch.randn( @@ -7088,9 +7173,188 @@ def test_unbind_backward(self, device, dtype): a, b, c = nt.unbind() b.sum().backward() - expected_grad = torch.zeros_like(nt) - expected_grad.unbind()[1].add_(1.0) - torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad) + @torch._dynamo.disable + def check(nt): + expected_grad = torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + self.assertEqual(nt.grad, expected_grad) + + check(nt) + + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + PADDING_VAL = 4.2 + expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL) + for i, component in enumerate(nt.unbind()): + expected_padded[i, : component.shape[0]].copy_(component) + + padded = nt.to_padded_tensor(PADDING_VAL) + self.assertEqual(expected_padded, padded) + + # convert padded dense -> NJT + from torch.nested._internal.nested_tensor import nested_from_padded + + nt2 = nested_from_padded(padded, nt.offsets()) + self.assertEqual(nt, nt2) + + if requires_grad: + # ensure gradients flow through conversions + nt2.backward(torch.ones_like(nt2)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + + # blows up due to test parametrization otherwise + @torch._dynamo.utils.disable_cache_limit() + @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @dtypes(torch.float32, torch.double, torch.half) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) + + nt = torch.nested.nested_tensor( + [ + torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, + ) + + def f(x): + return x.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + @torch.compile(fullgraph=True) + def g(nt): + def _g(nt): + PADDING_VAL = 4.2 + padded = nt.to_padded_tensor(PADDING_VAL) + padded = f(padded) + # NB: sum_S must be specified to use the lowering for dense -> jagged + # and get full fusion + return nested_from_padded( + padded, nt.offsets(), sum_S=nt.values().shape[0] + ) + + # NB: use checkpointing to force fusion + return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False) + + expected_output = f(nt) + if requires_grad: + expected_output.backward(torch.ones_like(expected_output)) + expected_grad = nt.grad.clone().detach() + nt.grad = None + + from torch._inductor.utils import run_and_get_code + + compiled_output, generated_code = run_and_get_code(g, nt) + if requires_grad: + compiled_output.backward(torch.ones_like(compiled_output)) + compiled_grad = nt.grad.clone().detach() + self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3) + + self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3) + + # === Verify that computation fusion happens. === + # Fallback op call -> fusion didn't happen. + fallback_op_calls_present = any( + "torch.ops.aten._padded_dense_to_jagged_forward.default(" + in generated_code[i] + or "torch.ops.aten._jagged_to_padded_dense_forward.default(" + in generated_code[i] + for i in range(len(generated_code)) + ) + + # NB: Fusion isn't supported on CPU. + self.assertEqual("cuda" in device, not fallback_op_calls_present) + + for i in range(len(generated_code)): + # Examine buffer construction lines in the generated code to determine + # whether fusion occurred. If fusion happens, a 3D buffer with shape + # (B, max_seqlen, D) should never be materialized. + buffer_constructions = [ + line.strip() + for line in generated_code[i].split("\n") + if "empty_strided_cuda(" in line + ] + + buffer_dims = [ + # buffer dim == number of elements in the tensor size tuple arg + len(ast.parse(t).body[0].value.args[0].elts) + for t in buffer_constructions + ] + + if "cuda" in device: + self.assertFalse(any(d == 3 for d in buffer_dims)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_padded_dense_conversion_preserves_metadata_cache( + self, device, dtype + ): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def g(nt): + padded = nt.to_padded_tensor(0.3) + intermediate = padded.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + return nested_from_padded( + intermediate, + nt.offsets(), + min_seqlen=nt._min_seqlen, + max_seqlen=nt._max_seqlen, + sum_S=nt.values().shape[0], + ) + + output = g(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) FORWARD_FAILURES = { @@ -7159,8 +7423,6 @@ def test_unbind_backward(self, device, dtype): "jiterator_binary", "jiterator_binary_return_by_ref", "jiterator_unary", - # Bug found: sum() with keepdim=True returns invalid shape - "sum", # RuntimeError: prod(): keepdim=True must be set for NestedTensor "prod", # RuntimeError: "jagged_to_padded_dense" not implemented for 'Bool' @@ -7194,6 +7456,8 @@ def test_unbind_backward(self, device, dtype): "clone", # Calling into torch.ops.aten.size directly "masked_select", + # NotImplementedError: aten._nested_sum_backward.default. Need to fix the backward pass. + "sum", } COMPILE_FORWARD_FAILURES = { @@ -7201,6 +7465,9 @@ def test_unbind_backward(self, device, dtype): # clone() on non-contiguous with holes NJTs currently use unbind(), leading to # data-dependent error in torch.compile "clone", + # torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default + # for inputs where min / max seqlen are not cached + "sum", } COMPARE_TENSOR_COMPONENT_EQUALITY = { diff --git a/test/test_nn.py b/test/test_nn.py index eb4ccd76515309..7d0241f97dc210 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -39,10 +39,11 @@ from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input -from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \ +from torch.testing._internal.common_device_type import dtypesIfMPS, instantiate_device_type_tests, dtypes, \ dtypesIfCUDA, precisionOverride, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \ skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, \ - onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types + onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, expectedFailureMPS, \ + skipMeta, get_all_device_types from hypothesis import given import torch.testing._internal.hypothesis_utils as hu @@ -3077,6 +3078,7 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) + @skipIfRocm(msg='Large numerical errors') def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -4828,6 +4830,28 @@ def helper(self, mod, size, dtype, mixed_dtype=False, format=torch.channels_last mixed_dtype = False helper(self, nn.BatchNorm3d, shape, dtype, mixed_dtype, torch.channels_last_3d, precisons[dtype]) + def test_batchnorm_half_overflow(self): + def helper(self, mod, size, format): + channels = size[1] + input = torch.randn(size, dtype=torch.half, device='cpu', requires_grad=True) + input = input.contiguous(memory_format=format) + bn = mod(channels).cpu().to(torch.half) + out = bn(input) + + ref_bn = mod(channels).cpu().to(torch.float) + ref_bn.load_state_dict(bn.to(torch.float).state_dict()) + ref_out = ref_bn(input) + + self.assertFalse(out.isinf().any()) + self.assertFalse(out.isnan().any()) + self.assertEqual(out, ref_out) + + for format in [torch.contiguous_format, torch.channels_last]: + helper(self, nn.BatchNorm2d, (4, 80, 500, 500), format) + + for format in [torch.contiguous_format, torch.channels_last_3d]: + helper(self, nn.BatchNorm3d, (4, 80, 20, 100, 100), format) + @parametrize_test( 'bn_module', [ @@ -6201,7 +6225,7 @@ def test_affine_grid(self): warnings.simplefilter("always") # python2 requires this so other tests can trigger self.assertTrue(gradcheck( lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), - (inp,))) + (inp,), check_forward_ad=True)) # test CPU against CUDA if TEST_CUDA: @@ -6253,7 +6277,7 @@ def test_affine_grid_3d(self): warnings.simplefilter("always") # python2 requires this so other tests can trigger self.assertTrue(gradcheck( lambda inp: F.affine_grid(inp, sz, align_corners=align_corners), - (inp,))) + (inp,), check_forward_ad=True)) # test CPU against CUDA if TEST_CUDA: @@ -7919,6 +7943,7 @@ def _test_module_empty_inputs(self, module, inputs): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off() @bf32_on_and_off() def test_affine_2d_rotate0(self, device): @@ -7959,6 +7984,7 @@ def test_affine_2d_rotate0(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.001) @bf32_on_and_off(0.001) def test_affine_2d_rotate90(self, device): @@ -8008,6 +8034,7 @@ def test_affine_2d_rotate90(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_2d_rotate45(self, device): @@ -8085,6 +8112,7 @@ def test_avg_pool_large_tensor2(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_2d_rotateRandom(self, device): @@ -8137,6 +8165,7 @@ def test_affine_2d_rotateRandom(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @expectedFailureMPS # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764 @tf32_on_and_off(0.005) @bf32_on_and_off(0.005) def test_affine_3d_rotateRandom(self, device): @@ -8200,6 +8229,7 @@ def test_batchnorm_large_batch(self, device, dtype): out = bn(data).sum().backward() @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128) + @dtypesIfMPS(torch.float, torch.half, torch.complex64) @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) def test_conv_empty_input(self, device, dtype): def help(input, conv, memory_format): @@ -8576,6 +8606,7 @@ def test_ReplicationPad_empty(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'padding size is expected to be 6'): torch._C._nn.replication_pad3d(torch.randn([2]), padding=[]) + @expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447 def test_ReplicationPad1d_large(self, device): shapes = ([2, 65736, 4], [65736, 2, 4]) pl, pr = 3, 4 @@ -8600,6 +8631,7 @@ def test_ReplicationPad1d_large(self, device): self.assertEqual(x.grad[:, :, 0], g[:, :, : pl + 1].sum(-1)) self.assertEqual(x.grad[:, :, -1], g[:, :, -pr - 1:].sum(-1)) + @expectedFailureMPS # Correctness issue https://github.com/pytorch/pytorch/issues/135447 def test_ReplicationPad2d_large(self, device): shapes = ([2, 65736, 4, 4], [65736, 2, 4, 4]) pl, pr, pt, pb = 3, 4, 5, 6 @@ -8938,9 +8970,11 @@ def check_rnn_grads(rnn1, rnn2): else: self.assertEqual(hx.grad, hx_device.grad) - def test_BatchNorm_empty(self, device): + @dtypesIfMPS(torch.float) + @dtypes(torch.double) + def test_BatchNorm_empty(self, device, dtype): mod = torch.nn.BatchNorm2d(3).to(device) - inp = torch.randn(0, 3, 2, 2, device=device) + inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype) _test_module_empty_input(self, mod, inp) if self.device_type == 'cuda' and self.has_cudnn(): with torch.backends.cudnn.flags(enabled=False): @@ -8967,7 +9001,7 @@ def test_linear_empty(self, device): def test_one_hot(self, device): # cuda throws device assert for invalid data # xla ignores out of bound indices - if self.device_type != 'cuda' and self.device_type != 'xla': + if self.device_type not in ('cuda', 'mps', 'xla'): with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) @@ -9016,6 +9050,7 @@ def test_one_hot(self, device): with self.assertRaises(RuntimeError): torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764 def test_nn_empty(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -9031,6 +9066,7 @@ def verify_scalars(input, output): output = m(input) verify_scalars(input, output) + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise https://github.com/pytorch/pytorch/issues/77764 def test_nn_scalars(self, device): # One off tests to ensure scalars from nn.yaml are properly applied def verify_scalars(input, output): @@ -9209,6 +9245,9 @@ def func(device): # We don't want to make propagating NaN a hard requirement on ops, but for # these easy ones, we should make them do so. + # MPS: NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764 + # MPS: NotImplementedError: aten::hardshrink.out https://github.com/pytorch/pytorch/issues/77764 + @expectedFailureMPS def test_nonlinearity_propagate_nan(self, device): def test(nonlinearity, *args, **kwargs): x = torch.tensor([nan], device=device) @@ -9241,6 +9280,7 @@ def test(nonlinearity, *args, **kwargs): test('threshold', 3, 2) test('threshold', 3, 2, inplace=True) + @expectedFailureMPS # TypeError: float64 the MPS framework doesn't support float64 @parametrize_test("mode", ["nearest-exact", "nearest"]) def test_upsamplingNearest1d(self, device, mode): # Forward AD does not support XLA because XLA tensors don't have storage @@ -9319,6 +9359,7 @@ def test_upsamplingNearestExact1d_rescale(self, device): expected_out = in_t.repeat_interleave(2, dim=-1) self.assertEqual(out_t, expected_out) + @skipIfMps # Partially passes https://github.com/pytorch/pytorch/issues/134430 @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact1d_correctness(self, device, isize, osize): # Here we check if output matches Scikit-Image/Scipy-like result @@ -9337,6 +9378,7 @@ def test_upsamplingNearestExact1d_correctness(self, device, isize, osize): expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("mode", ["nearest", "nearest-exact"]) def test_upsamplingNearest2d(self, device, memory_format, mode): @@ -9425,6 +9467,7 @@ def test_upsamplingNearest2d_correctness(self, device, memory_format, isize, osi expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @skipIfMps # Partially passes https://github.com/pytorch/pytorch/issues/134430 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize, osize): @@ -9448,6 +9491,7 @@ def test_upsamplingNearestExact2d_correctness(self, device, memory_format, isize expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) @parametrize_test("mode", ["nearest", "nearest-exact"]) def test_upsamplingNearest3d(self, device, memory_format, mode): @@ -9521,6 +9565,7 @@ def test_upsamplingNearest3d_correctness(self, device, memory_format, isize, osi expected_out = expected_out.to(device=device) self.assertEqual(out_t, expected_out) + @expectedFailureMPS # NotImplementedError: aten::_upsample_nearest_exact3d.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) @parametrize_test("isize, osize", [(20, 11), (10, 15)]) def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize, osize): @@ -9643,6 +9688,7 @@ def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_cha else: _ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias) + @expectedFailureMPS # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): # NOTE: We expand the batch dim such that `b*c` is above the maximum @@ -9663,6 +9709,8 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bilinear", align_corners=False, antialias=True) self.assertEqual(expected_out.expand([*shape[:2], 2, 2]), t_out) + # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764 + @skipIfMps @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @parametrize_test("mode", ["bilinear", "bicubic"]) @parametrize_test("antialias", [True, False]) @@ -9768,6 +9816,7 @@ def test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_f ) torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0) + @expectedFailureMPS # NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764 def test_upsamplingBicubic2d_correctness(self, device): # test output against known input: align_corners=False result must match opencv in_t = torch.arange(8., device=device).view(1, 2, 2, 2) @@ -9785,6 +9834,7 @@ def test_upsamplingBicubic2d_correctness(self, device): torch.set_printoptions(precision=5) self.assertEqual(out_t, expected_out_t, atol=1e-5, rtol=0) + @expectedFailureMPS # NotImplementedError: aten::_upsample_bicubic2d_aa.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8) @@ -9801,6 +9851,7 @@ def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True) self.assertEqual(expected_out, t_out) + @expectedFailureMPS # NotImplementedError: aten::upsample_trilinear3d.out https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("align_corners", [True, False]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last_3d]) def test_upsamplingTrilinear3d(self, device, align_corners, memory_format): @@ -10483,8 +10534,8 @@ def _test_gumbel_softmax_grad(self, device, dtype): tol = 2 * torch.finfo(dtype).eps self.assertEqual(logits_soft.grad, logits_hard.grad, atol=tol, rtol=0) - @skipIfMps @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypesIfMPS(torch.float) @dtypes(torch.float, torch.double) def test_gumbel_softmax(self, device, dtype): self._test_gumbel_softmax_st_shapes(device, dtype, shape=[5], dim=0, count_expected=1) @@ -10512,6 +10563,7 @@ def _test_rnn_retain_variables(self, device, dtype): self.assertEqual(grads, grads2) @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypesIfMPS(torch.half, torch.float) @dtypes(torch.double) def test_rnn_retain_variables(self, device, dtype): self._test_rnn_retain_variables(device, dtype) @@ -10561,6 +10613,7 @@ def flatten_out(mod, inp): # Merge into OpInfo? @skipMeta # LSTM cell reuses output which was resized + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @dtypes(torch.double) def test_LSTM_grad_and_gradgrad(self, device, dtype): hsize = 4 @@ -10570,6 +10623,7 @@ def test_LSTM_grad_and_gradgrad(self, device, dtype): self._test_rnn_mod(mod, inp) @skipMeta # GRU cell reuses output which was resized + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @dtypes(torch.double) def test_GRU_grad_and_gradgrad(self, device, dtype): hsize = 4 @@ -10738,6 +10792,7 @@ def _assertEqual_list(self, expected, list_to_compare, atol=None, rtol=None): for ele in list_to_compare: self.assertEqual(expected, ele, atol=atol, rtol=rtol) + @expectedFailureMPS # NotImplementedError: aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 @parametrize_test("reduction", ['none', 'mean', 'sum']) @parametrize_test("use_module_form", [True, False]) def test_CTCLoss_no_batch_dim(self, device, reduction, use_module_form): @@ -10868,6 +10923,7 @@ def _test_batchnorm_grad(self, device, dtype=torch.double): _assertGradAndGradgradChecks(self, F.batch_norm, (input, running_mean, running_var, weight, bias, training, 0.1, 0.0001)) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_batchnorm_grad(self, device): self._test_batchnorm_grad(device) @@ -10906,6 +10962,7 @@ def test_layernorm_weight_bias(self): out_zero_bias = torch.layer_norm(input, normalized_shape, data, bias, eps) self.assertEqual(out_none_bias, out_zero_bias) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_hardsigmoid_grad(self, device): inputs = (torch.randn(4, 16, 16, device=device, dtype=torch.double) - 0.5) * 10 inputs.requires_grad = True @@ -11112,6 +11169,7 @@ def test_grid_sample_nan_inf(self, device, dtype): padding_mode=padding_mode, align_corners=False) self.assertEqual(sample, torch.zeros([1, 1, 1, 2], device=device, dtype=dtype)) + @expectedFailureMPS # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 def test_CTCLoss_empty_target(self, device): target_lengths = [0, 0, 0] input_lengths = [50, 50, 50] @@ -11132,6 +11190,7 @@ def test_CTCLoss_empty_target(self, device): # Merge into OpInfo? @skipCUDAIf(True, """Test is flaky on Linux and Windows, typical error message: https://github.com/pytorch/pytorch/issues/34870""") + @expectedFailureMPS # NotImplementedError aten::_ctc_loss https://github.com/pytorch/pytorch/issues/77764 def test_ctc_loss(self, device): batch_size = 64 num_labels = 101 @@ -11234,6 +11293,7 @@ def test_ctc_loss_cudnn_tensor(self, device): grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out) self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0) + @expectedFailureMPS # RuntimeError: LSTM with projections is not currently supported with MPS. @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) @tf32_on_and_off(0.005) @@ -11490,8 +11550,8 @@ def test_cross_entropy_64bit(self, device, reduction): print(logits.numel(), labels.numel(), loss.numel()) self.assertTrue(torch.allclose(loss_cpu, loss.cpu(), rtol=1e-4, atol=1e-4)) - def _nll_loss_helper(self, input_size, reduction, expected, device): - input = torch.rand(input_size, requires_grad=True, device=device) + def _nll_loss_helper(self, input_size, reduction, expected, device, dtype): + input = torch.rand(input_size, requires_grad=True, device=device, dtype=dtype) num_channels = input_size[1] target_size = (input_size[0], ) + tuple(input_size[2:]) target = torch.randint(num_channels, target_size, device=device) @@ -11502,28 +11562,34 @@ def _nll_loss_helper(self, input_size, reduction, expected, device): output.sum().backward() self.assertEqual(input.grad.size(), input.size()) - def test_nll_loss_empty_tensor_reduction_none(self, device): - self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device) - self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device) - self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device) - self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device) - self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device) - - def test_nll_loss_empty_tensor_reduction_mean(self, device): + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_none(self, device, dtype): + self._nll_loss_helper([0, 3], "none", torch.empty([0], device=device), device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "none", torch.empty([0, 5, 7], device=device), device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "none", torch.empty([2, 0, 7], device=device), device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "none", torch.empty([2, 5, 0], device=device), device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "none", torch.empty([2, 5, 7, 0], device=device), device, dtype) + + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_mean(self, device, dtype): nan = torch.tensor(float('nan'), device=device) - self._nll_loss_helper([0, 3], "mean", nan, device) - self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device) - self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device) - self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device) - self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device) + self._nll_loss_helper([0, 3], "mean", nan, device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "mean", nan, device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "mean", nan, device, dtype) - def test_nll_loss_empty_tensor_reduction_sum(self, device): + @dtypesIfMPS(torch.half, torch.float) + @dtypes(torch.float) + def test_nll_loss_empty_tensor_reduction_sum(self, device, dtype): zero = torch.tensor(0, device=device) - self._nll_loss_helper([0, 3], "sum", zero, device) - self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device) - self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device) - self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device) - self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device) + self._nll_loss_helper([0, 3], "sum", zero, device, dtype) + self._nll_loss_helper([0, 3, 5, 7], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 0, 7], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 5, 0], "sum", zero, device, dtype) + self._nll_loss_helper([2, 3, 5, 7, 0], "sum", zero, device, dtype) def test_nll_loss_total_weight_is_zero(self, device): @@ -11722,6 +11788,7 @@ def test_cross_entropy_label_smoothing_errors(self, device): r"label_smoothing must be between 0\.0"): loss(*input_arg) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 @set_default_dtype(torch.double) def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device): N, C = 10, 4 @@ -11877,6 +11944,7 @@ def test_softshrink_negative(self, device): r'lambda must be greater or equal to 0, but found to be -1\.'): m(input) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_fold(self, device): def test_dtype(fn, input, dtype): input = input.detach().clone().to(dtype=dtype).requires_grad_(True) @@ -11915,6 +11983,7 @@ def test_logsigmoid_out(self, device): # Check that clip_grad_norm_ raises an error if the total norm of the # parameters' gradients is non-finite + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_clip_grad_norm_error_if_nonfinite(self, device): norms_pos = [0.1, 1, 2, 3.5, inf] norms_neg = [-0.1, -1, -2, -3.5] @@ -12041,7 +12110,8 @@ def __init__(self) -> None: self.assertEqual(p.grad.to(devices[0]), pe.grad) def test_elu_inplace_overlap(self, device): - x = torch.randn((1, 6), dtype=torch.bfloat16, device=device).expand((6, 6)) + dtype = torch.bfloat16 if device != 'mps:0' else torch.float16 + x = torch.randn((1, 6), dtype=dtype, device=device).expand((6, 6)) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.elu(x, inplace=True) with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): @@ -12082,6 +12152,7 @@ def test_softplus_inplace_overlap(self, device): with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): F.softplus(x, out=x) + @expectedFailureMPS # TypeError: the MPS framework doesn't support float64 def test_softplus_low_threshold(self, device): # Ensure gradients are computed correctly with a low threshold. model = torch.nn.Softplus(threshold=1).double() @@ -12103,6 +12174,7 @@ def test_leaky_relu_inplace_overlap(self, device): F.leaky_relu_(x) # Merge into OpInfo? + @expectedFailureMPS # NotImplementedError: aten::rrelu_with_noise_ https://github.com/pytorch/pytorch/issues/77764 def test_leaky_relu_inplace_with_neg_slope(self, device): a = torch.tensor([-1., 1.], device=device, requires_grad=True) b = torch.nn.functional.leaky_relu_(a.clone(), -2) @@ -12122,10 +12194,11 @@ def test_leaky_relu_inplace_with_zero_slope(self, device): expected = torch.tensor([0., 0., 1.], device=device) self.assertEqual(a.grad, expected) - a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=torch.bfloat16, requires_grad=True) + dtype = torch.bfloat16 if device != 'mps:0' else torch.float16 + a_bf16 = torch.tensor([-2., 0., 2.], device=device, dtype=dtype, requires_grad=True) b_bf16 = torch.nn.functional.leaky_relu_(a_bf16.clone(), 0.0) b_bf16.backward(torch.ones(3, device=device)) - expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=torch.bfloat16) + expected_bf16 = torch.tensor([0., 0., 1.], device=device, dtype=dtype) self.assertEqual(a_bf16.grad, expected_bf16) @onlyCPU @@ -12238,15 +12311,13 @@ def cosine_distance(x, y): self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6) self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6) - def test_to_complex(self, device): + @dtypesIfMPS(torch.cfloat, torch.float) + @dtypes(torch.cfloat, torch.cdouble, torch.float) + def test_to_complex(self, device, dtype): m = nn.Linear(3, 5).to(device) self.assertIs(m, m.to(device)) - m.to(torch.cfloat) - self.assertIs(m.weight.dtype, torch.cfloat) - m.to(torch.cdouble) - self.assertIs(m.weight.dtype, torch.cdouble) - m.to(torch.float) - self.assertIs(m.weight.dtype, torch.float) + m.to(dtype) + self.assertIs(m.weight.dtype, dtype) with warnings.catch_warnings(record=True) as w: # Trigger warning m.to(torch.cfloat) @@ -12255,6 +12326,7 @@ def test_to_complex(self, device): self.assertTrue("Complex modules are a new feature" in str(w[-1].message)) @skipMeta + @dtypesIfMPS(torch.float32) @dtypes(torch.float32, torch.float64) def test_module_to_empty(self, device, dtype): class MyModule(nn.Module): @@ -12338,6 +12410,8 @@ def test_skip_init(self, device): self.assertEqual(m_initialized.weight.device, m_uninitialized.weight.device) self.assertFalse(torch.allclose(m_initialized.weight, m_uninitialized.weight)) + @skipIfRocm(msg='Not our bug: TransformerEncoderLayer._sa_block still uses FA/ME and effectively takes fastpath') + @skipIfMps # TODO(hvaara): Investigate as possible bug. macOS 13 passes, while 14 and 15 fails. @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) def test_transformerencoderlayer(self, device, dtype): @@ -12643,6 +12717,7 @@ def perm_fn(x): with cm: _test(activation=activation, batch_first=batch_first, training=training) + @skipIfMps # RuntimeError: foreach=True was passed, but can't use the foreach API on mps tensors @parametrize_test('foreach', (False, True)) def test_clip_grad_value(self, foreach, device): if torch.device(device).type == 'xla' and foreach: @@ -12670,6 +12745,7 @@ def test_clip_grad_value(self, foreach, device): clip_grad_value_([p2], clip_value, foreach=foreach) self.assertEqual(p1.grad, p2.grad) + @skipIfMps # TypeError: the MPS framework doesn't support float64 @parametrize_test('foreach', (False, True)) @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf')) def test_clip_grad_norm(self, norm_type, foreach, device): @@ -12768,6 +12844,7 @@ def test_adaptiveavg_pool1d_shmem(self, device): self.assertEqual(x.grad, x_cpu.grad) @skipMeta + @expectedFailureMPS # NotImplementedError: aten::channel_shuffle https://github.com/pytorch/pytorch/issues/77764 def test_channel_shuffle(self, device): # 3D tensor x = torch.tensor( @@ -12940,7 +13017,8 @@ def __init__(self) -> None: self.assertEqual(list(state_dict.keys()), list(ddp_state_dict.keys())) self.assertEqual(list(state_dict._metadata.keys()), list(ddp_state_dict._metadata.keys())) -instantiate_device_type_tests(TestNNDeviceType, globals()) + +instantiate_device_type_tests(TestNNDeviceType, globals(), allow_mps=True) instantiate_parametrized_tests(TestNN) if __name__ == '__main__': diff --git a/test/test_ops.py b/test/test_ops.py index e3627d62ab94b3..e1b5ebfe87d248 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2476,9 +2476,16 @@ def map_to_fake(e): # if you see a shape exception here, you may need to add # a `dynamic_output_shape` tag to an operator - # prims/decomps must correctly model strides, - # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 - prims.utils.compare_tensor_meta(fake_out, real_out, True) + if op.op not in [ + torch.ops.aten._efficient_attention_forward, + torch.ops.aten._flash_attention_forward, + ]: + # prims/decomps must correctly model strides, + # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325 + + # note: the excluded ops have intentionally incorrect device; + # see "Note [Seed and Offset]" (_meta_registrations.py) + prims.utils.compare_tensor_meta(fake_out, real_out, True) if name not in aliasing_failures: fake_aliasing = outputs_alias_inputs( diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 1a35e6823a189e..255b177a10eda7 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1358,8 +1358,8 @@ def forward(self, crop_camera_1, mask_1): mul_4 = sym_size_int * 3 view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None - view_4 = torch.ops.aten.view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None - index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_4); crop_camera_1 = mask_1 = view_4 = index_put_ = None + _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None + index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None return None""") # noqa: B950 def test_unbacked_slice(self): diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 5433540aeb2b6a..e2d7d8236c6159 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -3,13 +3,14 @@ import importlib import inspect import json +import logging import os import pkgutil import unittest from typing import Callable import torch -from torch._utils_internal import get_file_path_2 +from torch._utils_internal import get_file_path_2 # @manual from torch.testing._internal.common_utils import ( IS_JETSON, IS_MACOS, @@ -20,6 +21,9 @@ ) +log = logging.getLogger(__name__) + + class TestPublicBindings(TestCase): def test_no_new_reexport_callables(self): """ @@ -267,7 +271,9 @@ def test_modules_can_be_imported(self): failures = [] def onerror(modname): - failures.append((modname, ImportError)) + failures.append( + (modname, ImportError("exception occurred importing package")) + ) for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror): modname = mod.name @@ -279,8 +285,8 @@ def onerror(modname): importlib.import_module(modname) except Exception as e: # Some current failures are not ImportError - - failures.append((modname, type(e))) + log.exception("import_module failed") + failures.append((modname, e)) # It is ok to add new entries here but please be careful that these modules # do not get imported by public code. @@ -441,14 +447,16 @@ def onerror(modname): } errors = [] - for mod, excep_type in failures: + for mod, exc in failures: if mod in public_allowlist: # TODO: Ensure this is the right error type continue if mod in private_allowlist: continue - errors.append(f"{mod} failed to import with error {excep_type}") + errors.append( + f"{mod} failed to import with error {type(exc).__qualname__}: {str(exc)}" + ) self.assertEqual("", "\n".join(errors)) # AttributeError: module 'torch.distributed' has no attribute '_shard' diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 7c938e302118c3..00e86b1f9977b2 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1793,6 +1793,23 @@ def test_tolist_numpy_with_torch_dispatch_mode(self) -> None: with self.assertRaises(AssertionError): self.assertEqual(x, None) + # See https://github.com/pytorch/pytorch/issues/136064 + def test_view_returns_alias_under_torch_dispatch(self): + class MyMode(TorchDispatchMode): + def __init__(self, testcase): + self.testcase = testcase + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + out = func(*args, **kwargs) + if func == torch.ops.aten.view.dtype: + # view should return a fresh TensorImpl + self.testcase.assertTrue(out is not args[0]) + return out + + with MyMode(self): + x = torch.ones(4, dtype=torch.float32) + out = x.view(torch.float32) + def test_record_stream(self) -> None: class TestMode(TorchDispatchMode): def __init__(self, testcase): diff --git a/test/test_reductions.py b/test/test_reductions.py index 8e07692cb58931..c31408d8de6f19 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -487,10 +487,14 @@ def test_dim_reduction_lastdim(self, device, dtype): self.assertEqual(y, y2) @skipIfNoSciPy - def test_logsumexp(self, device): + @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128) + def test_logsumexp(self, device, dtype): from scipy.special import logsumexp - a = torch.randn(5, 4, device=device) - a[0, 0] = inf + a = torch.randn(5, 4, device=device, dtype=dtype) + # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp, + # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740 + if torch.device(device) != torch.device('cpu'): + a[0, 0] = inf a[1, :] = -inf actual = a.logsumexp(1) expected = logsumexp(a.cpu().numpy(), 1) @@ -498,11 +502,14 @@ def test_logsumexp(self, device): self.assertEqual(expected, actual) # check that out is actually inplace - b = torch.zeros(5, 2, device=device) + b = torch.zeros(5, 2, device=device, dtype=dtype) c = b[:, 0] torch.logsumexp(a, 1, out=c) self.assertEqual(expected, b[:, 0]) + @skipIfNoSciPy + def test_logsumexp_integral_promotion(self, device): + from scipy.special import logsumexp # check integral inputs is promoted to floating point e = torch.randint(-100, 100, [5, 4], device=device) actual = e.logsumexp(1).to(torch.float64) @@ -1498,7 +1505,13 @@ def test_prod_lowp(self, device, dtype): self.assertEqual(res1, res2.to(dtype=dtype)) def test_prod_bool(self, device): - vals = [[True, True], [True, False], [False, False], []] + vals = [ + [True, True], + [True, False], + [False, False], + [], + [False] * 256, # https://github.com/pytorch/pytorch/issues/127866 + ] for val in vals: result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item() expect = np.prod(np.array(val), dtype=bool) diff --git a/test/test_serialization.py b/test/test_serialization.py index f663188a959da8..3ba96b80541d88 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,5 +1,6 @@ # Owner(s): ["module: serialization"] +import contextlib import copy import gc import gzip @@ -19,6 +20,7 @@ from pathlib import Path import torch +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -27,6 +29,7 @@ LoadEndianness, safe_globals, set_default_load_endianness, + skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4052,7 +4055,7 @@ def test_serialization_warning_s390x(self): @parametrize('path_type', (str, Path)) @parametrize('weights_only', (True, False)) @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows") - def test_serialization_mmap_loading(self, weights_only, path_type): + def test_serialization_mmap_loading_options(self, weights_only, path_type): class DummyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -4101,7 +4104,7 @@ def forward(self, input): for v in result.values(): self.assertTrue(v.is_cuda) - def test_serialization_mmap_loading_options(self): + def test_serialization_mmap_loading(self): if IS_WINDOWS: with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"): torch.serialization.set_default_mmap_options(2) @@ -4111,22 +4114,36 @@ def test_serialization_mmap_loading_options(self): with tempfile.NamedTemporaryFile() as f: torch.save(sd, f) # with MmapVisibility.MAP_PRIVATE, should not be able to modify file - sd_loaded = torch.load(f.name, mmap=True) + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) sd_loaded['weight'][0][0] = 0 - sd_loaded2 = torch.load(f.name, mmap=True) + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) self.assertEqual(sd_loaded2['weight'], sd['weight']) # with MmapVisibility.MAP_SHARED, should be able to modify file torch.serialization.set_default_mmap_options(MAP_SHARED) try: - sd_loaded = torch.load(f.name, mmap=True) + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) sd_loaded['weight'][0][0] = 0 - sd_loaded2 = torch.load(f.name, mmap=True) + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) self.assertNotEqual(sd_loaded2['weight'], sd['weight']) self.assertEqual(sd_loaded2['weight'][0][0].item(), 0) self.assertEqual(sd_loaded2['weight'], sd_loaded['weight']) finally: torch.serialization.set_default_mmap_options(MAP_PRIVATE) + @unittest.skipIf(IS_WINDOWS, "mmap ctx doesn't work on Windows") + def test_serialization_mmap_loading_ctx(self): + sd = torch.nn.Linear(3, 5).state_dict() + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + with torch.serialization.set_default_mmap_options(MAP_SHARED): + sd_loaded = torch.load(f.name, mmap=True, weights_only=True) + sd_loaded['weight'][0][0] = 0 + sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True) + self.assertNotEqual(sd_loaded2['weight'], sd['weight']) + self.assertEqual(sd_loaded2['weight'][0][0].item(), 0) + self.assertEqual(sd_loaded2['weight'], sd_loaded['weight']) + self.assertTrue(torch.serialization.get_default_mmap_options() == MAP_PRIVATE) + @parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32)) @parametrize('weights_only', (True, False)) def test_serialization_dtype(self, dtype, weights_only): @@ -4198,6 +4215,91 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization(self, materialize_fake): + # Create one tensor that uses each of the paths in __reduce_ex__ that should work + t_device = "cuda" if torch.cuda.is_available() else "cpu" + t_v2 = torch.randn(2, 3, device=t_device) + t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) + i = torch.tensor([[0, 1, 1], + [2, 0, 2]]) + v = torch.tensor([3, 4, 5], dtype=torch.float32) + if not materialize_fake: + # FakeTensorConverter messes up sizes of i and v for the sparse tensor + st = torch.sparse_coo_tensor(i, v, (2, 4)) + tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) + + mode, converter = FakeTensorMode(), FakeTensorConverter() + + def fn(t): + return converter.from_real_tensor(mode, t) if materialize_fake else t + + sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} + sd_expected = { + 't_v2': torch.zeros(2, 3, device=t_device), + 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), + 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), + } + + if not materialize_fake: + sd['st'] = st + sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) + + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + with safe_globals([TwoTensor]): + sd_loaded = torch.load(f, weights_only=True) + self.assertEqual(sd_loaded, sd_expected, exact_device=True) + self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) + self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) + + # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx + if not materialize_fake: + ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) + with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): + with skip_data(), BytesIOContext() as f: + torch.save(ft, f) + + @parametrize("materialize_fake", (True, False)) + def test_skip_data_serialization_preserves_views(self, materialize_fake): + ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext + with ctx(): + t = torch.randn(2, 3) + t_view = t.view(-1) + t_slice = t[1] + sd = {'t': t, 't_view': t_view, 't_slice': t_slice} + with BytesIOContext() as f: + with skip_data(materialize_fake_tensors=materialize_fake): + torch.save(sd, f) + f.seek(0) + sd_loaded = torch.load(f, weights_only=True) + self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) + + def test_skip_data_serialization_error_cases(self): + def _save_load(t): + with BytesIOContext() as f: + with skip_data(): + torch.save(t, f) + f.seek(0) + torch.load(f, weights_only=True) + + nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) + t = torch.randn(2, 3, device="meta") + with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): + _save_load(nt) + + with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): + _save_load(t) + + with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): + with skip_data(), BytesIOContext() as f: + torch.save(torch.randn(2, 3), f) + f.seek(0) + torch.load(f, weights_only=True) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py index df847543600ed2..aebfdaec0cb854 100644 --- a/test/test_sort_and_select.py +++ b/test/test_sort_and_select.py @@ -402,10 +402,10 @@ def test(shape): if tensor.size() != torch.Size([]): if dtype is torch.bfloat16: expected = torch.from_numpy( - np.msort(tensor.float().cpu().numpy()) + np.sort(tensor.float().cpu().numpy(), axis=0) ).bfloat16() else: - expected = torch.from_numpy(np.msort(tensor.cpu().numpy())) + expected = torch.from_numpy(np.sort(tensor.cpu().numpy(), axis=0)) else: expected = tensor # numpy.msort() does not support empty shapes tensor diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 903b609c4fb236..f897fd041889f8 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -4027,6 +4027,7 @@ def test_TensorAsKey(self, device): @skipIfRocm @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8) + @precisionOverride({torch.float16: 6e-1}) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") def test_triton_kernel(self, op, device, dtype, blocksize): from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm @@ -4039,15 +4040,24 @@ def bsr_dense_linear(input, weights, bias=None): operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op] - def reference(input, mat1, mat2, beta=1, alpha=1, op=op): + def reference(input, mat1, mat2, beta=1, alpha=1, left_alpha=None, right_alpha=None, op=op): assert mat1.layout is torch.strided assert mat2.layout is torch.strided if dtype is torch.int8: if op == '_int_bsr_dense_addmm': - return beta * input + alpha * torch._int_mm(mat1, mat2) - # workaround RuntimeError: "addmm_cuda" not implemented for 'Char' - return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8) - return beta * input + alpha * (mat1 @ mat2) + mat12 = torch._int_mm(mat1, mat2) + else: + # workaround RuntimeError: "addmm_cuda" not implemented for 'Char' + mat12 = torch._int_mm(mat1, mat2).to(torch.int8) + else: + mat12 = mat1 @ mat2 + if alpha != 1: + mat12 *= alpha + if left_alpha is not None: + mat12 = left_alpha.reshape(*left_alpha.shape[:-1], -1, 1) * mat12 + if right_alpha is not None: + mat12 = mat12 * right_alpha.reshape(*right_alpha.shape[:-1], 1, -1) + return beta * input + mat12 if op == '_int_bsr_dense_addmm': # _int_bsr_dense_addmm is same as bsr_dense_addmm except @@ -4056,6 +4066,8 @@ def reference(input, mat1, mat2, beta=1, alpha=1, op=op): # definitions above and all other definitions below are # identical between _int_bsr_dense_addmm and # bsr_dense_addmm. + if dtype.is_floating_point or dtype.is_complex: + self.skipTest(f"Redundant test: {op} on {dtype} tensors") op = 'bsr_dense_addmm' def nc_copy(t, axes=(-1,)): @@ -4101,14 +4113,21 @@ def nc_copy(t, axes=(-1,)): blocks_per_row_lst = [1, 2] blocks_per_col_lst = [1, 2] result_cols_lst = [16, 32, 64] - for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product( - beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst): + has_left_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op] + has_right_alpha_lst = dict(bsr_dense_addmm=[False, True], bsr_dense_mm=[False], bsr_dense_linear=[False])[op] + high = 1.5 + int(dtype is torch.int8) + for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N, has_left_alpha, has_right_alpha in itertools.product( + beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst, + has_left_alpha_lst, has_right_alpha_lst): M = BM * blocks_per_row K = BK * blocks_per_col mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device) bsr = mat1.to_sparse_bsr((BM, BK)) - mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5) - input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5) + mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=high) + input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=high) + + left_alpha = make_tensor(M, dtype=dtype, device=device, low=0.5, high=high) if has_left_alpha else None + right_alpha = make_tensor(N, dtype=dtype, device=device, low=0.5, high=high) if has_right_alpha else None if 0 and op == "bsr_dense_addmm": # Find optimal kernel parameters, the speed-up is @@ -4121,12 +4140,12 @@ def nc_copy(t, axes=(-1,)): meta = get_meta(op, key, version=(0, dtype, 0.5)) if meta is None: optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5) - meta = get_meta(op, key, version=(0, dtype, 0.5)) assert meta is not None dump() # this will update torch/sparse/_triton_ops_meta.py - expected = reference(input, mat1, mat2, beta=beta, alpha=alpha) - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={}, + expected = reference(input, mat1, mat2, beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha) + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, + left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_mm={}, bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op] args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2), @@ -4156,7 +4175,7 @@ def nc_copy(t, axes=(-1,)): if op in {'bsr_dense_addmm', 'bsr_dense_linear'}: args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2), bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op] - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op] result = operation(*args, **kwargs) self.assertEqual(result, expected) diff --git a/test/test_stateless.py b/test/test_stateless.py index a68d8ad9b5f703..a62e88d2caf034 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -738,6 +738,8 @@ def forward(self, inp, *, other_inp): self.assertEqual(res, other_inp) res_1 = functional_call(mod, a, (), {'inp': inp, 'other_inp': other_inp}) self.assertEqual(res, res_1) + res_2 = functional_call(mod, a, kwargs={'inp': inp, 'other_inp': other_inp}) + self.assertEqual(res, res_2) def test_functional_call_tuple_dicts(self): mod = MockModule() diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 0eb4be644ede65..81ed1126dcbb27 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -1,30 +1,35 @@ # Owner(s): ["oncall: pt2"] +import functools import itertools import math import sys +from typing import Callable, List, Tuple, Type import sympy -from typing import Callable, List, Tuple, Type + +import torch +import torch.fx as fx +from sympy.core.relational import is_ge, is_gt, is_le, is_lt from torch.testing._internal.common_device_type import skipIf from torch.testing._internal.common_utils import ( - TEST_Z3, instantiate_parametrized_tests, parametrize, run_tests, + TEST_Z3, TestCase, ) -from torch.utils._sympy.functions import FloorDiv -from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve -from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges -from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis +from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd from torch.utils._sympy.interp import sympy_interp -from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity -from sympy.core.relational import is_ge, is_le, is_gt, is_lt -import functools -import torch.fx as fx - +from torch.utils._sympy.reference import ( + PythonReferenceAnalysis, + ReferenceAnalysis, + TensorReferenceAnalysis, +) +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve +from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges UNARY_OPS = [ @@ -39,10 +44,18 @@ "ceil", ] BINARY_OPS = [ - "truediv", "floordiv", + "truediv", + "floordiv", # "truncdiv", # TODO # NB: pow is float_pow - "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" + "add", + "mul", + "sub", + "pow", + "pow_by_natural", + "minimum", + "maximum", + "mod", ] UNARY_BOOL_OPS = ["not_"] @@ -152,12 +165,12 @@ def test_int_infinity(self): self.assertIs(-(-int_oo), int_oo) # noqa: B002 self.assertIs(abs(int_oo), int_oo) self.assertIs(abs(-int_oo), int_oo) - self.assertIs(int_oo ** 2, int_oo) + self.assertIs(int_oo**2, int_oo) self.assertIs((-int_oo) ** 2, int_oo) self.assertIs((-int_oo) ** 3, -int_oo) - self.assertEqual(int_oo ** -1, 0) + self.assertEqual(int_oo**-1, 0) self.assertEqual((-int_oo) ** -1, 0) - self.assertIs(int_oo ** int_oo, int_oo) + self.assertIs(int_oo**int_oo, int_oo) self.assertTrue(int_oo == int_oo) self.assertFalse(int_oo != int_oo) self.assertTrue(-int_oo == -int_oo) @@ -325,7 +338,9 @@ def test_binary_ref_range(self, fn): class TestSympyInterp(TestCase): - @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) def test_interp(self, fn): # SymPy does not implement truncation for Expressions if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): @@ -335,8 +350,8 @@ def test_interp(self, fn): if fn == "pow_by_natural": is_integer = True - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + x = sympy.Dummy("x", integer=is_integer) + y = sympy.Dummy("y", integer=is_integer) vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: @@ -358,10 +373,14 @@ def test_interp(self, fn): ref_r = getattr(ReferenceAnalysis, fn)(*sargs) # Yes, I know this is a longwinded way of saying xreplace; the # point is to test sympy_interp - r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr) + r = sympy_interp( + ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr + ) self.assertEqual(ref_r, r) - @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) def test_python_interp_fx(self, fn): # These never show up from symbolic_shapes if fn in ("log", "exp"): @@ -383,8 +402,8 @@ def test_python_interp_fx(self, fn): if fn == "pow_by_natural": is_integer = True - x = sympy.Dummy('x', integer=is_integer) - y = sympy.Dummy('y', integer=is_integer) + x = sympy.Dummy("x", integer=is_integer) + y = sympy.Dummy("y", integer=is_integer) symbols = [x] if arity == 2: @@ -411,26 +430,125 @@ def test_python_interp_fx(self, fn): sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) if arity == 1: + def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + return sympy_interp( + PythonReferenceAnalysis, {x: px}, sympy_expr + ) + else: + def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + return sympy_interp( + PythonReferenceAnalysis, {x: px, y: py}, sympy_expr + ) gm = fx.symbolic_trace(trace_f) self.assertEqual( - sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), - gm(*args) + sympy_interp( + PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr + ), + gm(*args), ) + @parametrize( + "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS + ) + def test_tensor_interp(self, fn): + # Skip operations not implemented or not applicable for tensors + if fn in ("div", "truncdiv", "int_truediv", "mod", "round_decimal"): + return + + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Symbol("x", integer=is_integer) + y = sympy.Symbol("y", integer=is_integer) + + vals = CONSTANTS + if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: + vals = [True, False] + + arity = 1 + if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: + arity = 2 + + symbols = [x] + if arity == 2: + symbols = [x, y] + + for args in itertools.product(vals, repeat=arity): + if arity == 1 and not valid_unary(fn, *args): + continue + elif arity == 2 and not valid_binary(fn, *args): + continue + + with self.subTest(args=args): + tensor_args = [ + torch.tensor( + a, dtype=torch.double if isinstance(a, float) else torch.int64 + ) + for a in args + ] + + try: + tensor_fn = getattr(TensorReferenceAnalysis, fn) + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + direct_result = tensor_fn(*tensor_args) + interp_result = sympy_interp( + TensorReferenceAnalysis, + dict(zip(symbols, tensor_args)), + sympy_expr, + ) + + # Ensure both results are of the same dtype for comparison + if direct_result.dtype != interp_result.dtype: + if ( + direct_result.dtype == torch.bool + or interp_result.dtype == torch.bool + ): + direct_result = direct_result.to(torch.bool) + interp_result = interp_result.to(torch.bool) + else: + direct_result = direct_result.to(torch.double) + interp_result = interp_result.to(torch.double) + + self.assertTrue( + torch.allclose( + direct_result, interp_result, rtol=1e-5, atol=1e-8 + ), + f"Mismatch for {fn}{args}: direct={direct_result}, interp={interp_result}", + ) + + if fn in UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS: + self.assertEqual(direct_result.dtype, torch.bool) + self.assertEqual(interp_result.dtype, torch.bool) + + if fn in ( + "floor_to_int", + "ceil_to_int", + "round_to_int", + "trunc_to_int", + ): + self.assertEqual(direct_result.dtype, torch.int64) + self.assertEqual(interp_result.dtype, torch.int64) + + except NotImplementedError: + print(f"Operation {fn} not implemented for TensorReferenceAnalysis") + except Exception as e: + self.fail(f"Unexpected error for {fn}{args}: {str(e)}") + def type_name_fn(type: Type) -> str: return type.__name__ + def parametrize_relational_types(*types): def wrapper(f: Callable): return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f) + return wrapper @@ -492,7 +610,13 @@ def test_noop_rhs(self, op): self.assertEqual(r_expr, mirror(rhs, lhs)) self.assertEqual(r_rhs, lhs) - def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs): + def _test_cases( + self, + cases: List[Tuple[sympy.Basic, sympy.Basic]], + thing: sympy.Basic, + op: Type[sympy.Rel], + **kwargs, + ): for source, expected in cases: r = try_solve(source, thing, **kwargs) @@ -563,7 +687,7 @@ def test_multiplication_division_inequality(self, op): @parametrize_relational_types() def test_floordiv(self, op): - from sympy import Eq, Ne, Gt, Ge, Lt, Le + from sympy import Eq, Ge, Gt, Le, Lt, Ne a, b, c = sympy.symbols("a b c") pos = sympy.Symbol("pos", positive=True) @@ -603,10 +727,12 @@ def test_floordiv(self, op): r_op = op self._test_cases([special_case, *cases], a, r_op) - self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False) + self._test_cases( + [(special_case[0], None), *cases], a, r_op, floordiv_inequality=False + ) def test_floordiv_eq_simplify(self): - from sympy import Eq, Lt, Le + from sympy import Eq, Le, Lt a = sympy.Symbol("a", positive=True, integer=True) @@ -667,6 +793,24 @@ def test_z3_proof_floordiv_eq_simplify(self): r = solver.check() self.assertEqual(r, z3.unsat) + def test_simple_floordiv_gcd(self): + x, y, z = sympy.symbols("x y z") + + # positive tests + self.assertEqual(simple_floordiv_gcd(x, x), x) + self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128) + self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128) + self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128) + self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x) + self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x) + self.assertEqual(simple_floordiv_gcd(x * y, x), x) + self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y) + self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x) + + # negative tests + self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1) + + class TestSingletonInt(TestCase): def test_basic(self): j1 = SingletonInt(1, coeff=1) diff --git a/test/test_transformers.py b/test/test_transformers.py index d0e3495a686514..168e4f903b0fda 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -347,6 +347,9 @@ def test_train_with_pad_and_catch_error(self, device): @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): + if TEST_WITH_ROCM: + if attn_mask_dim is not None and mask_dtype == torch.bool: + self.skipTest("boolean mask is not fully supported on ROCm yet.") # MHA converts all with torch.no_grad(): B = 2 @@ -429,6 +432,7 @@ def hook(module, inputs, output): # remove hook handle.remove() + @skipIfRocm @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @@ -1585,7 +1589,7 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): q, k, v, None, 0.0, False)) @onlyCUDA - @skipIfRocm # Nested Tensor + @skipIfRocm(msg='enable_gqa=True unsupported') @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel): @@ -1601,7 +1605,7 @@ def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_ke is_causal=False, enable_gqa=True) @onlyCPU - @skipIfRocm # Nested Tensor + @skipIfRocm(msg='enable_gqa=True unsupported') def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) @@ -1623,6 +1627,8 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype) size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257) + if TEST_WITH_ROCM: # On ROCM, FA and EA share the backend GPU kernels + size = SdpaShape(2, 2, 3, 257) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1665,8 +1671,9 @@ def test_unaligned_tensors(self, device): make_tensor = partial(torch.rand, size, device=device, dtype=dtype) q, k, v = make_tensor(), make_tensor(), make_tensor() with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): - self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( - q, k, v, None, 0.0, False)) + ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() + with ctxmgr: + torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware") @@ -1978,8 +1985,8 @@ def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: to @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16]) @parametrize("batch_size", [2, 12]) - @parametrize("q_seq_len", [514, 1030]) - @parametrize("kv_seq_len", [514]) + @parametrize("q_seq_len", [11, 514, 1030]) + @parametrize("kv_seq_len", [17, 514]) @parametrize("n_head", [1, 3]) @parametrize("head_dim", [8]) @parametrize("mask_dim", [2, 4]) @@ -2074,7 +2081,7 @@ def test_scaled_dot_product_fused_attention_mask_vs_math_cpu( self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol) self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol) - def test_scaled_dot_product_fused_attention_with_inf(self, device): + def test_sdpa_with_inf(self, device): # https://github.com/pytorch/pytorch/issues/127055. full = torch.full((600, 600), float("-inf"), device=device) mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10) @@ -2089,6 +2096,43 @@ def test_scaled_dot_product_fused_attention_with_inf(self, device): actual = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) self.assertEqual(math_ref, actual) + def test_sdpa_backward_with_gradient(self, device): + # https://github.com/pytorch/pytorch/issues/133671. + def sdpa_helper(): + torch.manual_seed(777) + query = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + key = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + value = ( + torch.empty(size=[2, 2, 49, 32], dtype=torch.float32, device=device) + .uniform_(-1, 1) + .requires_grad_(True) + ) + res = torch.nn.functional.scaled_dot_product_attention( + query, key, value, None, 0.0, False + ) + res_grad = ( + torch.empty_like(res, device=device) + .uniform_(-1, 1) + ) + res.backward(res_grad, retain_graph=True) + return res, query.grad, key.grad, value.grad + with sdpa_kernel(backends=[SDPBackend.MATH]): + res_ref, query_grad_ref, key_grad_ref, value_grad_ref = sdpa_helper() + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + res_actual, query_grad_actual, key_grad_actual, value_grad_actual = sdpa_helper() + self.assertEqual(res_ref, res_actual) + self.assertEqual(query_grad_ref, query_grad_actual) + self.assertEqual(key_grad_ref, key_grad_actual) + self.assertEqual(value_grad_ref, value_grad_actual) + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("backend", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION]) @parametrize("seq_len", [32, 64, 128]) @@ -2870,6 +2914,8 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation + if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: + self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -2917,15 +2963,27 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad) + fudge_factors = { + 'out': 3.0 , + 'grad_query': 150.0 , + 'grad_key': 25.0, + 'grad_value': 8.5, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 160.0 + fudge_factors['grad_query'] = 650.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - 'out': 3.0 , - 'grad_query': 150.0 , - 'grad_key': 25.0, - 'grad_value': 8.5, - } + fudge_factors=fudge_factors, ) @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA") @@ -3014,16 +3072,28 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value, attn_mask), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref, attn_mask_ref), upstream_grad) + fudge_factors = { + "out": 4, + "grad_query": 150.0, + "grad_key": 25.0, + "grad_value": 8.0, + "grad_attn_mask": 45.0, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 160.0 + fudge_factors['grad_query'] = 650.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - "out": 4, - "grad_query": 160.0, - "grad_key": 25.0, - "grad_value": 8.0, - "grad_attn_mask": 45.0, - }, + fudge_factors=fudge_factors, ) @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @@ -3036,7 +3106,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @parametrize("dropout_p", [0.0, 0.22, 0.48]) @parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("scale", [None, "l1"]) - @parametrize("enable_gqa", [True, False]) + @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False]) @parametrize("n_heads", [[16, 8], [10, 2]]) def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, @@ -3124,18 +3194,31 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad) grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad) + fudge_factors = { + 'out': 4, + 'grad_query': 160.0, + 'grad_key': 16, + 'grad_value': 4, + } + if TEST_WITH_ROCM: + fudge_factors['grad_key'] = 45.0 + fudge_factors['grad_query'] = 360.0 + if seq_len_k >= 1024: + fudge_factors['grad_key'] = 70.0 + if seq_len_k >= 2048: + fudge_factors['grad_key'] = 190.0 + fudge_factors['grad_query'] = 650.0 + if seq_len_q >= 2048: + fudge_factors['grad_query'] = 1100.0 + if dtype == torch.float32: + fudge_factors['grad_key'] = 90.0 + check_out_and_grad( (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), - fudge_factors={ - 'out': 4, - 'grad_query': 160.0, - 'grad_key': 16, - 'grad_value': 4, - } + fudge_factors=fudge_factors, ) - @skipIfRocm # FIXME: "capturing stream has unjoined work" @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("batch_size", [1, 8]) @parametrize("seq_len_q", [256, 1024]) @@ -3183,6 +3266,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") + if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: + self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 n_heads = 4 @@ -3279,10 +3364,10 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d (out_ref, out_lp_ref, out), *zip(grads_ref, grads_ref_lp, grads), fudge_factors={ - 'out': 2.0, + 'out': 3.0, 'grad_query': 100.0, 'grad_key': 8.0, - 'grad_value': 2.0, + 'grad_value': 3.0, } ) @@ -3682,6 +3767,7 @@ def test_causal_variants_compile(self, device, causal_variant: CausalVariant, sh self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") + @skipIfRocm @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): make_tensor = partial( diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 6a1ae247c14320..a4bbb8394da215 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -8,8 +8,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, make_tensor, TEST_NUMPY, set_default_dtype, torch_to_numpy_dtype_dict, - numpy_to_torch_dtype_dict, skipIfTorchDynamo, - xfailIfTorchDynamo) + numpy_to_torch_dtype_dict, skipIfTorchDynamo) from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, dtypes, onlyCPU, expectedFailureMeta, skipMeta) from torch.testing._internal.common_dtype import ( @@ -45,7 +44,6 @@ class TestTypePromotion(TestCase): # Promoting inplace would require re-allocating and copying the memory of the # tensor data, since element size could change. # https://github.com/pytorch/pytorch/issues/127049 - @xfailIfTorchDynamo @float_double_default_dtype def test_inplace(self, device): int_tensor = torch.ones([4, 4, 4], dtype=torch.int32, device=device) diff --git a/test/test_utils_internal.py b/test/test_utils_internal.py new file mode 100644 index 00000000000000..9d0b4d4d57d342 --- /dev/null +++ b/test/test_utils_internal.py @@ -0,0 +1,143 @@ +# Owner(s): ["module: unknown"] + +import os + +from torch._utils_internal import justknobs_feature, JustKnobsConfig +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + load_tests, +) + + +# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests + +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestJustKnob(TestCase): + def test_justknob_config(self): + with self.subTest("Returns True"): + a = JustKnobsConfig() + self.assertTrue(a.get()) + with self.subTest("Returns False"): + a = JustKnobsConfig(name="fake_name", default=False) + self.assertFalse(a.get()) + with self.subTest("Returns True via config"): + a = JustKnobsConfig(name="fake_name", default=False) + a.set(True) + self.assertTrue(a.get()) + with self.subTest("Returns True via env"): + os.environ["FAKE_FEATURE"] = "1" + a = JustKnobsConfig( + name="fake_name", env_name="FAKE_FEATURE", default=False + ) + self.assertTrue(a.get()) + with self.subTest("Returns same value consistently"): + a = JustKnobsConfig(name="fake_name", default=False) + a.set(True) + self.assertTrue(a.get()) + a.set(False) + self.assertTrue(a.get()) + with self.subTest("Checks __bool__"): + a = JustKnobsConfig(name="fake_name", default=False) + if a: + raise RuntimeError("Should not be true") + self.assertFalse(a) + + def test_justknob_feature(self): + with self.subTest("OSS is True"): + self.assertTrue(justknobs_feature("testname")) + with self.subTest("OSS default=True"): + self.assertTrue(justknobs_feature("testname", default=True)) + with self.subTest("OSS default=False"): + self.assertFalse(justknobs_feature("testname", default=False)) + with self.subTest("OSS config=True, default=False"): + self.assertTrue( + justknobs_feature("testname", config_value=True, default=False) + ) + with self.subTest("OSS config=None, default=False"): + self.assertFalse( + justknobs_feature("testname", config_value=None, default=False) + ) + with self.subTest("OSS config=False, default=True"): + self.assertFalse( + justknobs_feature("testname", config_value=False, default=True) + ) + with self.subTest("OSS env is missing, config=False, default=True"): + self.assertFalse( + justknobs_feature( + "testname", config_value=False, env_name="NOTDEFINED", default=False + ) + ) + with self.subTest("OSS env is missing, default=False"): + self.assertFalse( + justknobs_feature("testname", env_name="NOTDEFINED", default=False) + ) + with self.subTest( + "OSS config overrides env, config=True, env=False, default=False" + ): + os.environ["FEATURE_ENV"] = "0" + self.assertTrue( + justknobs_feature( + "testname", + config_value=True, + env_name="FEATURE_ENV", + default=False, + ) + ) + with self.subTest("OSS env overrides default, , default=False"): + os.environ["FEATURE_ENV"] = "1" + self.assertTrue( + justknobs_feature("testname", env_name="FEATURE_ENV", default=False) + ) + with self.subTest("OSS env truthy, config=False, default=False"): + os.environ["FEATURE_ENV"] = "1" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "true" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "TRUE" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + os.environ["FEATURE_ENV"] = "very weird true" + self.assertTrue( + justknobs_feature( + "testname", + env_name="FEATURE_ENV", + default=False, + ) + ) + with self.subTest("OSS env false, default=True"): + os.environ["FEATURE_ENV"] = "0" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + os.environ["FEATURE_ENV"] = "false" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + os.environ["FEATURE_ENV"] = "FALSE" + self.assertFalse( + justknobs_feature("testname", env_name="FEATURE_ENV", default=True) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_xpu.py b/test/test_xpu.py index 9dde7d8a71cf3f..471a422ab0b0c7 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -1,6 +1,5 @@ # Owner(s): ["module: intel"] -import collections import subprocess import sys import tempfile @@ -8,7 +7,7 @@ import torch import torch.xpu._gpu_trace as gpu_trace -from torch.testing._internal.autocast_test_lists import AutocastTestLists +from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyXPU, @@ -367,11 +366,58 @@ def test_serialization_array_with_empty(self): self.assertIs(type(copy), type(original)) self.assertEqual(copy.get_device(), original.get_device()) + def test_out_of_memory(self): + tensor = torch.zeros(1024, device="xpu") + + with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"): + torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="xpu") + + with self.assertRaisesRegex(RuntimeError, "XPU out of memory."): + torch.empty(1024 * 1024 * 1024 * 8000000000, dtype=torch.int8, device="xpu") + + def test_raises_oom(self): + torch.xpu.memory.empty_cache() + with self.assertRaises(torch.OutOfMemoryError): + torch.empty(1024 * 1024 * 1024 * 1024, device="xpu") + + def test_memory_allocation(self): + torch.xpu.empty_cache() + prev = torch.xpu.memory_allocated() + a = torch.ones(10, device="xpu") + self.assertGreater(torch.xpu.memory_allocated(), prev) + self.assertGreater(torch.xpu.memory_reserved(), 0) + del a + self.assertEqual(torch.xpu.memory_allocated(), prev) + torch.xpu.empty_cache() + self.assertEqual(torch.xpu.memory_reserved(), 0) + torch.xpu.reset_accumulated_memory_stats() + # Activate 1kB memory + a = torch.randn(256, device="xpu") + # Detect if the current active memory is 1kB + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 1024) + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 0) + del a + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.current"], 0) + self.assertEqual(torch.xpu.memory_stats()["active_bytes.all.freed"], 1024) + + @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected") + def test_device_memory_allocated(self): + device_count = torch.xpu.device_count() + current_alloc = [torch.xpu.memory_allocated(idx) for idx in range(device_count)] + x = torch.ones(10, device="xpu:0") + self.assertGreater(torch.xpu.memory_allocated(0), current_alloc[0]) + self.assertTrue( + all( + torch.xpu.memory_allocated(idx) == current_alloc[idx] + for idx in range(1, device_count) + ) + ) + instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True) -class TestXpuAutocast(TestCase): +class TestXpuAutocast(TestAutocast): # These operators are not implemented on XPU backend and we can NOT fall back # them to CPU. So we have to skip them at this moment. # TODO: remove these operators from skip list when they are implemented on XPU backend. @@ -385,89 +431,6 @@ def tearDown(self): del self.autocast_lists super().tearDown() - def _run_autocast_outofplace( - self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None - ): - # helper to cast args - def cast(val, to_type): - if isinstance(val, torch.Tensor): - return val.to(to_type) if val.is_floating_point() else val - elif isinstance(val, collections.abc.Iterable): - return type(val)(cast(v, to_type) for v in val) - else: - return val - - if add_kwargs is None: - add_kwargs = {} - fast_dtype = torch.bfloat16 if run_as_type == torch.bfloat16 else torch.float16 - self.assertFalse(torch.is_autocast_enabled("xpu")) - with torch.amp.autocast("xpu", dtype=fast_dtype): - self.assertTrue(torch.is_autocast_enabled("xpu")) - - out_type = out_type if out_type is not None else run_as_type - output = output_method = None - - # Try module.* variant, if requested: - if module is not None and hasattr(module, op): - output = getattr(module, op)(*args, **add_kwargs) - if isinstance(output, torch.Tensor): - self.assertTrue( - out_type == output.dtype, - f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", - ) - - # Try Tensor.* variant: - if hasattr(torch.Tensor, op): - output_method = getattr(args[0], op)(*args[1:], **add_kwargs) - if isinstance(output_method, torch.Tensor): - self.assertTrue( - out_type == output_method.dtype, - f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", - ) - - self.assertTrue( - (output is not None) or (output_method is not None), - f"{op} not found as an attribute on either Tensor or the requested module {module}", - ) - - # Accounts for ops that return Tensors, iterables, and other non-Tensors. - # For example, lstm_cell returns a tuple and equal returns bool. - def compare(first, second): - if isinstance(first, torch.Tensor): - return torch.equal(first, second) - elif isinstance(first, collections.abc.Iterable): - return all(compare(f, s) for f, s in zip(first, second)) - else: - return first == second - - # If both torch.* and Tensor.* variants were found, check outputs are identical - if (output is not None) and (output_method is not None): - self.assertTrue(type(output) == type(output_method)) - comparison = compare(output, output_method) - self.assertTrue( - comparison, f"torch.{op} result did not match Tensor.{op} result" - ) - - # Compare numerics to Python-side "autocasting" that (we expect) does the same thing - # as the C++-side autocasting, and should be bitwise accurate. - output_to_compare = output if output is not None else output_method - with torch.amp.autocast("xpu", enabled=False): - self.assertFalse(torch.is_autocast_enabled("xpu")) - - if module is not None and hasattr(module, op): - control = getattr(module, op)( - *cast(args, run_as_type), **add_kwargs - ) - else: - control = getattr(args[0].to(run_as_type), op)( - *cast(args[1:], run_as_type), **add_kwargs - ) - self.assertTrue(type(output_to_compare) == type(control)) - comparison = compare(output_to_compare, control) - self.assertTrue(comparison, f"torch.{op} result did not match control") - self.assertTrue(torch.is_autocast_enabled("xpu")) - self.assertFalse(torch.is_autocast_enabled("xpu")) - def test_autocast_torch_fp16(self): for op_with_args in self.autocast_lists.torch_fp16: skip_test = False @@ -477,7 +440,9 @@ def test_autocast_torch_fp16(self): if len(op_with_args) == 3: skip_test = True # skip cudnn op if not skip_test: - self._run_autocast_outofplace(op, args, torch.float16) + self._run_autocast_outofplace( + op, args, torch.float16, device="xpu", amp_dtype=torch.float16 + ) def test_autocast_torch_bf16(self): for op_with_args in self.autocast_lists.torch_fp16: @@ -488,15 +453,24 @@ def test_autocast_torch_bf16(self): if len(op_with_args) == 3: skip_test = True # skip cudnn op if not skip_test: - self._run_autocast_outofplace(op, args, torch.bfloat16) + self._run_autocast_outofplace(op, args, torch.bfloat16, device="xpu") def test_autocast_torch_need_autocast_promote(self): for op, args in self.autocast_lists.torch_need_autocast_promote: - self._run_autocast_outofplace(op, args, torch.float32) + self._run_autocast_outofplace( + op, args, torch.float32, device="xpu", amp_dtype=torch.float16 + ) def test_autocast_torch_expect_builtin_promote(self): for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote: - self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) + self._run_autocast_outofplace( + op, + args, + torch.float32, + device="xpu", + out_type=out_type, + amp_dtype=torch.float16, + ) def test_autocast_checkpointing(self): model = torch.nn.Sequential( diff --git a/test/torch_np/check_tests_conform.py b/test/torch_np/check_tests_conform.py index 05ff5357b7c7c2..c1795bfdc0f8d9 100644 --- a/test/torch_np/check_tests_conform.py +++ b/test/torch_np/check_tests_conform.py @@ -43,7 +43,6 @@ def check(path): else: report_violation(line, num, "off-class parametrize") if not src[nn - 1].startswith("@instantiate_parametrized_tests"): - # breakpoint() report_violation( line, num, f"missing instantiation of parametrized tests in {ln}?" ) diff --git a/test/torch_np/numpy_tests/core/test_einsum.py b/test/torch_np/numpy_tests/core/test_einsum.py index 029e10cead4ab1..5432fb63d18000 100644 --- a/test/torch_np/numpy_tests/core/test_einsum.py +++ b/test/torch_np/numpy_tests/core/test_einsum.py @@ -440,8 +440,6 @@ def check_einsum_sums(self, dtype, do_opt=False): # Suppress the complex warnings for the 'as f8' tests with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning) - # matvec(a,b) / a.dot(b) where a is matrix, b is vector for n in range(1, 17): a = np.arange(4 * n, dtype=dtype).reshape(4, n) diff --git a/test/torch_np/numpy_tests/core/test_indexing.py b/test/torch_np/numpy_tests/core/test_indexing.py index c5823dc30ff37c..d47f62dca505c9 100644 --- a/test/torch_np/numpy_tests/core/test_indexing.py +++ b/test/torch_np/numpy_tests/core/test_indexing.py @@ -602,16 +602,21 @@ def test_boolean_index_cast_assign(self): zero_array[bool_index] = np.array([1]) assert_equal(zero_array[0, 1], 1) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_ns else np.ComplexWarning + ) + # Fancy indexing works, although we get a cast warning. assert_warns( - np.ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) + ComplexWarning, zero_array.__setitem__, ([0], [1]), np.array([2 + 1j]) ) assert_equal(zero_array[0, 1], 2) # No complex part # Cast complex to float, throwing away the imaginary portion. - assert_warns( - np.ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j]) - ) + assert_warns(ComplexWarning, zero_array.__setitem__, bool_index, np.array([1j])) assert_equal(zero_array[0, 1], 0) @@ -1012,6 +1017,14 @@ def test_multidim(self): # This is so that np.array(True) is not accepted in a full integer # index, when running the file separately. warnings.filterwarnings("error", "", DeprecationWarning) + # np.VisibleDeprecationWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + VisibleDeprecationWarning = ( + np.exceptions.VisibleDeprecationWarning + if has_exceptions_ns + else np.VisibleDeprecationWarning + ) warnings.filterwarnings("error", "", np.VisibleDeprecationWarning) def isskip(idx): diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index baed25a9b021c3..f0ef1f79beb542 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -5469,8 +5469,6 @@ def test_out_arg(self): out = np.zeros((5, 2), dtype=np.complex128) c = self.matmul(a, b, out=out) assert_(c is out) - # with suppress_warnings() as sup: - # sup.filter(np.ComplexWarning, '') c = c.astype(tgt.dtype) assert_array_equal(c, tgt) @@ -5852,9 +5850,16 @@ def test_complex_warning(self): x = np.array([1, 2]) y = np.array([1 - 2j, 1 + 2j]) + # np.ComplexWarning moved to np.exceptions in numpy>=2.0.0 + # np.exceptions only available in numpy>=1.25.0 + has_exceptions_ns = hasattr(np, "exceptions") + ComplexWarning = ( + np.exceptions.ComplexWarning if has_exceptions_ns else np.ComplexWarning + ) + with warnings.catch_warnings(): - warnings.simplefilter("error", np.ComplexWarning) - assert_raises(np.ComplexWarning, x.__setitem__, slice(None), y) + warnings.simplefilter("error", ComplexWarning) + assert_raises(ComplexWarning, x.__setitem__, slice(None), y) assert_equal(x, [1, 2]) diff --git a/test/torch_np/numpy_tests/lib/test_arraypad.py b/test/torch_np/numpy_tests/lib/test_arraypad.py index befa9d76ac467a..9b16e105db62f7 100644 --- a/test/torch_np/numpy_tests/lib/test_arraypad.py +++ b/test/torch_np/numpy_tests/lib/test_arraypad.py @@ -543,7 +543,7 @@ def test_check_constant_odd_pad_amount(self): @xpassIfTorchDynamo # (reason="tuple values") def test_check_constant_pad_2d(self): arr = np.arange(4).reshape(2, 2) - test = np.lib.pad( + test = np.pad( arr, ((1, 2), (1, 3)), mode="constant", constant_values=((1, 2), (3, 4)) ) expected = np.array( diff --git a/third_party/XNNPACK b/third_party/XNNPACK index fcbf55af6cf28a..87ee0b46b834f6 160000 --- a/third_party/XNNPACK +++ b/third_party/XNNPACK @@ -1 +1 @@ -Subproject commit fcbf55af6cf28a4627bcd1f703ab7ad843f0f3a2 +Subproject commit 87ee0b46b834f67bad9025d4a82ed5654f3403d3 diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 3c8b1533ac03dd..a5ff6df40ce528 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 3c8b1533ac03dd6531ab6e7b9245d488f13a82a5 +Subproject commit a5ff6df40ce528721cfc310c7ed43946d77404d5 diff --git a/third_party/generate-xnnpack-wrappers.py b/third_party/generate-xnnpack-wrappers.py index 34171f676506cd..ee47aebc85bea9 100755 --- a/third_party/generate-xnnpack-wrappers.py +++ b/third_party/generate-xnnpack-wrappers.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 from __future__ import print_function +from pathlib import Path import collections import os import sys @@ -91,6 +92,11 @@ # add non-prod microkernel sources here: } +# Source files not needed in buck build. +IGNORED_SOURCES = set(( + "\"${PROJECT_BINARY_DIR}/build_identifier.c\"", # Not currently used and requires build-time codegen. +)) + def handle_singleline_parse(line): start_index = line.find("(") end_index = line.find(")") @@ -99,12 +105,24 @@ def handle_singleline_parse(line): return key_val[0], [x[4:] for x in key_val[1:]] def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"): + print(f"Updating sources from {cmakefile}") sources = collections.defaultdict(list) with open(os.path.join(xnnpack_path, cmakefile)) as cmake: lines = cmake.readlines() i = 0 while i < len(lines): line = lines[i] + + if lines[i].startswith("INCLUDE"): + file, _ = handle_singleline_parse(line) + if file.startswith("cmake/gen/"): + path = Path(xnnpack_path) / "XNNPACK" / file + local_sources = update_sources(xnnpack_path, path.absolute().as_posix()) + for k,v in local_sources.items(): + if k in sources: + sources[k] = sources[k] + local_sources[k] + else: + sources[k] = local_sources[k] if lines[i].startswith("SET") and "src/" in lines[i]: name, val = handle_singleline_parse(line) @@ -118,12 +136,14 @@ def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"): while i < len(lines) and len(lines[i]) > 0 and ')' not in lines[i]: # remove "src/" at the beginning, remove whitespaces and newline value = lines[i].strip(' \t\n\r') - sources[name].append(value[4:]) + if value not in IGNORED_SOURCES: + sources[name].append(value[4:]) i += 1 if i < len(lines) and len(lines[i]) > 4: # remove "src/" at the beginning, possibly ')' at the end value = lines[i].strip(' \t\n\r)') - sources[name].append(value[4:]) + if value not in IGNORED_SOURCES: + sources[name].append(value[4:]) else: i += 1 return sources @@ -132,7 +152,7 @@ def gen_wrappers(xnnpack_path): xnnpack_sources = collections.defaultdict(list) sources = update_sources(xnnpack_path) - microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake") + microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/gen/microkernels.cmake") for key in microkernels_sources: sources[key] = microkernels_sources[key] @@ -186,6 +206,8 @@ def gen_wrappers(xnnpack_path): def main(argv): + print("Generating wrappers...") + if argv is None or len(argv) == 0: gen_wrappers(".") else: diff --git a/third_party/ideep b/third_party/ideep index 383e9238c1c118..41d636c2bbcea6 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 383e9238c1c118c1ff72662f9ed7c8610150325d +Subproject commit 41d636c2bbcea6bff0faf97cdb65a48cdde987af diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index 64154894d65e69..aa55943bc913c9 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -11,9 +11,9 @@ _DNNL_RUNTIME_OMP = { "#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */", "#cmakedefine DNNL_SYCL_HIP": "/* #undef DNNL_SYCL_HIP */", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* undef DNNL_EXPERIMENTAL_UKERNEL */", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", - "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#undef DNNL_EXPERIMENTAL_UKERNEL", "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", @@ -138,6 +138,7 @@ cc_library( "DNNL_ENABLE_CONCURRENT_EXEC", "DNNL_ENABLE_PRIMITIVE_CACHE", "DNNL_ENABLE_CPU_ISA_HINTS", + "DNNL_EXPERIMENTAL_UKERNEL", "ONEDNN_BUILD_GRAPH", ], ) diff --git a/third_party/pybind11 b/third_party/pybind11 index 941f45bcb51457..a2e59f0e706540 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 941f45bcb51457884fa1afd6e24a67377d70f75c +Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4 diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index cb351261d4039c..1ca8e776de683b 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -4,7 +4,6 @@ load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") load( ":xnnpack_src_defs.bzl", - "JIT_SRCS", "LOGGING_SRCS", "OPERATOR_SRCS", "SUBGRAPH_SRCS", @@ -20,8 +19,8 @@ load( "PROD_AVX512F_MICROKERNEL_SRCS", "PROD_AVX512SKX_MICROKERNEL_SRCS", "PROD_AVX512VBMI_MICROKERNEL_SRCS", - "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS", + "PROD_AVX512VNNI_MICROKERNEL_SRCS", "PROD_AVXVNNI_MICROKERNEL_SRCS", "PROD_AVX_MICROKERNEL_SRCS", "PROD_F16C_MICROKERNEL_SRCS", @@ -47,6 +46,12 @@ load( "PROD_XOP_MICROKERNEL_SRCS", ) +XNN_COMMON_PREPROCESSOR_FLAGS = [ + "-DXNN_PRIVATE=", + "-DXNN_INTERNAL=", + "-DXNN_LOG_LEVEL=0" +] + # This defines XNNPACK targets for both fbsource BUCK and OSS BUCK # Note that the file path is relative to the BUCK file that called from, not to this bzl file. # So for fbsource build it points to xplat/third-party/XNNPACK/XNNPACK, @@ -79,9 +84,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }, apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], exported_deps = [ # Dependency only on pthreadpool interface @@ -100,15 +103,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - "-DXNN_ENABLE_JIT=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_ENABLE_SPARSE=0", "-DXNN_ENABLE_MEMOPT", ], @@ -134,15 +132,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -154,37 +147,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) - fb_xplat_cxx_library( - name = "jit_memory", - # srcs have to include HOT_SRCS to be able to build on ARVR - srcs = JIT_SRCS, - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-Oz", - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platforms = (APPLE, ANDROID, CXX, WINDOWS), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, - windows_compiler_flags_override = WINDOWS_FLAGS, - deps = [ - ":interface", - third_party("clog"), - ], - ) - fb_xplat_cxx_library( name = "ukernels_scalar", srcs = PROD_SCALAR_MICROKERNEL_SRCS, @@ -200,15 +162,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-fno-math-errno", "-ffp-contract=off", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -231,10 +188,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -250,10 +203,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], @@ -273,10 +225,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -286,10 +234,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse"], @@ -311,10 +258,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -330,10 +273,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE2_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], @@ -354,10 +296,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -367,10 +305,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse2"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse2"], @@ -393,10 +330,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -412,10 +345,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSSE3_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], @@ -436,10 +368,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -449,10 +377,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mssse3"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mssse3"], @@ -475,10 +402,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -494,10 +417,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_SSE41_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], @@ -518,10 +440,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -531,10 +449,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-msse4.1"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-msse4.1"], @@ -565,10 +482,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -584,10 +497,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -608,10 +520,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mavx", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -621,10 +529,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -664,10 +571,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512dq", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -690,9 +593,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -712,10 +613,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -732,9 +629,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -758,16 +653,22 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ] + select({ "DEFAULT": [], "ovr_config//cpu:x86_32": [ - "-mavx", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", ], "ovr_config//cpu:x86_64": [ - "-mavx", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -789,8 +690,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNI" ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], @@ -811,10 +713,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -830,8 +728,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNI" ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], @@ -857,10 +756,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", "-mfma", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -880,9 +775,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -902,10 +795,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -917,9 +806,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx"], @@ -949,10 +836,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -969,10 +852,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else []), platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], @@ -993,10 +875,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1007,10 +885,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mf16c"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mf16c"], @@ -1020,103 +897,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ) - fb_xplat_cxx_library( - name = "ukernels_xop", - srcs = PROD_XOP_MICROKERNEL_SRCS if is_arvr_mode() else [], - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ("XNNPACK/src", "**/*.c"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-O2", - ] + select({ - "DEFAULT": [], - "ovr_config//cpu:x86_32": [ - "-mxop", - ], - "ovr_config//cpu:x86_64": [ - "-mxop", - ], - }), - platform_compiler_flags = [ - ( - "x86|x86_64|platform009|platform010", - [ - "-mxop", - ], - ), - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platform_preprocessor_flags = [ - ( - "windows-x86_64", - [ - "-Drestrict=", - ], - ), - ], - platform_srcs = ([ - ( - "x86|x86_64|platform009|platform010", - PROD_XOP_MICROKERNEL_SRCS, - ), - ] if not is_arvr_mode() else []), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], - windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], - deps = [ - ":interface", - ], - ) - - fb_xplat_cxx_library( - name = "ukernels_xop_ovr_win32", - headers = subdir_glob([ - ("XNNPACK/src", "**/*.h"), - ("XNNPACK/src", "**/*.c"), - ]), - header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), - compiler_flags = [ - "-O2", - "-mxop", - ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], - labels = labels, - platform_preprocessor_flags = [ - ( - "windows-x86_64", - [ - "-Drestrict=", - ], - ), - ], - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], - visibility = ["PUBLIC"], - windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mxop"], - windows_compiler_flags_override = WINDOWS_FLAGS + ["-mxop"], - windows_srcs = PROD_XOP_MICROKERNEL_SRCS, - deps = [ - ":interface", - ], - ) - fb_xplat_cxx_library( name = "ukernels_fma3", srcs = PROD_FMA3_MICROKERNEL_SRCS if is_arvr_mode() else [], @@ -1139,10 +919,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1159,10 +935,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_FMA3_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mfma", @@ -1190,10 +965,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfma", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1204,10 +975,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mfma", @@ -1247,10 +1017,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1268,10 +1034,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX2_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx2", @@ -1302,10 +1067,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfma", "-mf16c", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1317,10 +1078,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx2", @@ -1328,6 +1088,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mf16c", ], windows_compiler_flags_override = WINDOWS_FLAGS + [ + "/D__AVX2__", "-mavx2", "-mfma", "-mf16c", @@ -1358,10 +1119,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512f", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1377,10 +1134,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512F_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], @@ -1419,10 +1175,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512vbmi", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1443,10 +1195,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512VBMI_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1481,10 +1232,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-O2", "-mavx512f", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1494,10 +1241,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + ["-mavx512f"], windows_compiler_flags_override = WINDOWS_FLAGS + ["-mavx512f"], @@ -1535,10 +1281,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512vl", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1558,10 +1300,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_AVX512SKX_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else []), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1576,6 +1317,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512bw", "-mavx512dq", "-mavx512vl", + ], deps = [ ":interface", @@ -1598,10 +1340,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512dq", "-mavx512vl", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1615,10 +1353,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [ "-mavx512f", @@ -1633,6 +1370,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mavx512bw", "-mavx512dq", "-mavx512vl", + "/D__AVX512BW__", ], windows_srcs = PROD_AVX512SKX_MICROKERNEL_SRCS, deps = [ @@ -1654,10 +1392,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-fno-fast-math", "-fno-math-errno", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1670,10 +1404,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1705,10 +1438,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1726,10 +1455,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEON_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1754,10 +1482,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], platform_srcs = [ ( "(aarch64|arm64)", @@ -1765,10 +1489,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1800,10 +1523,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-vfpv4", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1828,10 +1547,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFMA_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1856,10 +1574,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F compiler_flags = [ "-O2", ], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_srcs = [ ( @@ -1868,10 +1582,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1908,10 +1621,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-ffinite-math-only", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1928,10 +1637,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -1959,10 +1667,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -1974,10 +1678,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2001,10 +1704,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8-a"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2029,10 +1728,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2064,10 +1762,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfloat-abi=softfp", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2085,10 +1779,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOT_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2116,10 +1809,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8.2-a+dotprod"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2135,10 +1824,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOT_MICROKERNEL_SRCS + PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2186,15 +1874,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2224,10 +1907,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+dotprod+fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], platform_compiler_flags = [ ( "(aarch64|arm64)", @@ -2243,10 +1922,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ] if not is_arvr_mode() else [], labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2278,10 +1956,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp-armv8", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2299,10 +1973,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFP16ARITH_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2330,10 +2003,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "DEFAULT": [], "ovr_config//cpu:arm64": ["-march=armv8.2-a+fp16"], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2349,10 +2018,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F PROD_NEONFP16ARITH_MICROKERNEL_SRCS + PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS, ), ] if not is_arvr_mode() else [], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2384,10 +2052,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+i8mm+fp16", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2406,10 +2070,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, @@ -2438,10 +2101,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-mfpu=neon-fp-armv8", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2454,16 +2113,14 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ), ], platforms = (APPLE, ANDROID, CXX, WINDOWS), + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, deps = [ ":interface", - ":jit_memory", third_party("FP16"), ], ) @@ -2485,10 +2142,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-march=armv8.2-a+fp16+dotprod", ], }), - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], labels = labels, platform_compiler_flags = [ ( @@ -2498,16 +2151,14 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ], ), ], + fbandroid_link_whole = True, preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", - ], + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS, windows_compiler_flags_override = WINDOWS_FLAGS, deps = [ ":interface", - ":jit_memory", third_party("FP16"), ], ) @@ -2516,10 +2167,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "arm64_lib", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_asm_aarch64", ":ukernels_neon", ":ukernels_neon_aarch64", @@ -2554,7 +2205,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ":ukernels_sse2", ":ukernels_sse41", ":ukernels_ssse3", - ":ukernels_xop", ":ukernels_avx512vbmi", ":ukernels_avx512vnni", ":ukernels_avx512vnnigfni", @@ -2579,12 +2229,14 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ":ukernels_sse41_ovr_win32", ":ukernels_sse_ovr_win32", ":ukernels_ssse3_ovr_win32", - ":ukernels_xop_ovr_win32", ":ukernels_avx512vbmi", - ":ukernels_avx512vnni_ovr_win32", - ":ukernels_avx512vnnigfni_ovr_win32", + # ":ukernels_avx512vnni_ovr_win32", # Build crashes on Windows Clang 17.0.3, re-enable when fixed (T199959765) + # ":ukernels_avx512vnnigfni_ovr_win32", # ":ukernels_avxvnni_ovr_win32" Excluding avxvnni microkernels because they fail on older compilers ], + exported_preprocessor_flags = [ + "-DXNN_ENABLE_AVX512VNNIGFNI=0" + ] ) fb_xplat_cxx_library( @@ -2594,7 +2246,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_armsimd32", ":ukernels_asm_aarch32", ":ukernels_asm_aarch64", @@ -2619,10 +2270,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "armv7_lib", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ - ":jit_memory", ":ukernels_asm_aarch32", ":ukernels_neon", ":ukernels_neon_dot", @@ -2635,6 +2286,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "prod_ukernels", apple_sdks = (IOS, MACOSX, APPLETVOS), labels = labels, + fbandroid_link_whole = True, preferred_linkage = "static", visibility = ["PUBLIC"], deps = [ @@ -2668,19 +2320,13 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_headers = { "xnnpack.h": "XNNPACK/include/xnnpack.h", }, - fbobjc_preprocessor_flags = [ - "-DXNN_PRIVATE=", - "-DXNN_INTERNAL=", - ], header_namespace = "", headers = subdir_glob([ ("XNNPACK/src", "**/*.h"), ("XNNPACK/include", "**/*.h"), ]), platforms = (APPLE, ANDROID, CXX, WINDOWS), - preferred_linkage = "static", - preprocessor_flags = [ - "-DXNN_LOG_LEVEL=0", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_NO_Q8_OPERATORS", "-DXNN_NO_F16_OPERATORS", "-DXNN_NO_NCHW_OPERATORS", @@ -2690,7 +2336,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "-DXNN_NO_X8_OPERATORS", "-DXNN_ENABLE_MEMOPT", "-DXNN_ENABLE_SPARSE=0", - "-DXNN_ENABLE_JIT=0", "-DXNN_ENABLE_ASSEMBLY", "-DXNN_ENABLE_GEMM_M_SPECIALIZATION", "-DXNN_ENABLE_ARM_DOTPROD", @@ -2712,7 +2357,6 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "XNNPACK/src/memory.c", "XNNPACK/src/mutex.c", "XNNPACK/src/microparams-init.c", - "XNNPACK/src/operators/post-operation.c", ], visibility = ["PUBLIC"], windows_clang_compiler_flags_override = (WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS) if XNNPACK_WINDOWS_AVX512F_ENABLED else WINDOWS_FLAGS, diff --git a/third_party/xnnpack_src_defs.bzl b/third_party/xnnpack_src_defs.bzl index 296dacb58ec4c1..e9b3b6f9a9cef8 100644 --- a/third_party/xnnpack_src_defs.bzl +++ b/third_party/xnnpack_src_defs.bzl @@ -2,16 +2,12 @@ Auto-generated by generate-wrappers.py script. Do not modify """ -PROD_SCALAR_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/scalar.c", -] - -PROD_AVX512VNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vnni.c", +PROD_ARMSIMD32_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/armsimd32.c", ] -PROD_AVX512F_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512f.c", +PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", ] AARCH64_ASM_MICROKERNEL_SRCS = [ @@ -240,152 +236,22 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", ] -PROD_NEONV8_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonv8.c", -] - -PROD_AVX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx.c", -] - -LOGGING_SRCS = [ - "XNNPACK/src/enums/datatype-strings.c", - "XNNPACK/src/enums/microkernel-type.c", - "XNNPACK/src/enums/node-type.c", - "XNNPACK/src/enums/operator-type.c", - "XNNPACK/src/log.c", -] - -PROD_NEONI8MM_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neoni8mm.c", -] - -AARCH32_ASM_MICROKERNEL_SRCS = [ - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x1.S", - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x2.S", - "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x4.S", - "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x1.S", - "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x4.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-asm-aarch32-vfp-ld64.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-minmax-asm-aarch32-vfp-ld64.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", - "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-ld64.S", - "XNNPACK/src/f32-igemm/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", - "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", - "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", - "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", - "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", - "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", - "XNNPACK/src/qs16-qs8-vcvt/qs16-qs8-vcvt-asm-aarch32-neon-u16.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", - "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-arm-x1.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x1.S", - "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x2.S", -] - -PROD_F16C_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/f16c.c", -] - -PROD_XOP_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/xop.c", -] - -PROD_RVV_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/rvv.c", +PROD_AVXVNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avxvnni.c", ] SUBGRAPH_SRCS = [ @@ -404,25 +270,30 @@ SUBGRAPH_SRCS = [ "XNNPACK/src/subgraph/convert.c", "XNNPACK/src/subgraph/convolution-2d.c", "XNNPACK/src/subgraph/copy.c", + "XNNPACK/src/subgraph/copysign.c", "XNNPACK/src/subgraph/deconvolution-2d.c", "XNNPACK/src/subgraph/depth-to-space-2d.c", "XNNPACK/src/subgraph/depthwise-convolution-2d.c", "XNNPACK/src/subgraph/divide.c", "XNNPACK/src/subgraph/elu.c", "XNNPACK/src/subgraph/even-split.c", + "XNNPACK/src/subgraph/exp.c", "XNNPACK/src/subgraph/floor.c", "XNNPACK/src/subgraph/fully-connected-sparse.c", "XNNPACK/src/subgraph/fully-connected.c", + "XNNPACK/src/subgraph/gelu.c", "XNNPACK/src/subgraph/global-average-pooling.c", "XNNPACK/src/subgraph/global-sum-pooling.c", "XNNPACK/src/subgraph/hardswish.c", "XNNPACK/src/subgraph/leaky-relu.c", + "XNNPACK/src/subgraph/log.c", "XNNPACK/src/subgraph/max-pooling-2d.c", "XNNPACK/src/subgraph/maximum2.c", "XNNPACK/src/subgraph/minimum2.c", "XNNPACK/src/subgraph/multiply2.c", "XNNPACK/src/subgraph/negate.c", "XNNPACK/src/subgraph/prelu.c", + "XNNPACK/src/subgraph/reciprocal-square-root.c", "XNNPACK/src/subgraph/reshape-helpers.c", "XNNPACK/src/subgraph/scaled-dot-product-attention.c", "XNNPACK/src/subgraph/sigmoid.c", @@ -444,26 +315,40 @@ SUBGRAPH_SRCS = [ "XNNPACK/src/tensor.c", ] -PROD_FMA3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma3.c", +PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnnigfni.c", ] -PROD_AVX512SKX_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512skx.c", +PROD_AVX512VNNI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vnni.c", ] -JIT_SRCS = [ - "XNNPACK/src/jit/aarch32-assembler.cc", - "XNNPACK/src/jit/aarch64-assembler.cc", - "XNNPACK/src/jit/assembler.cc", +PROD_SSE2_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse2.c", ] -PROD_NEONFP16_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16.c", +PROD_NEONDOT_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neondot.c", ] -PROD_SSSE3_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/ssse3.c", +PROD_SSE41_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse41.c", +] + +PROD_SSE_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/sse.c", +] + +PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16arith.c", +] + +PROD_NEONV8_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonv8.c", +] + +PROD_NEONFP16_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfp16.c", ] XNNPACK_SRCS = [ @@ -500,69 +385,23 @@ XNNPACK_SRCS = [ "XNNPACK/src/params.c", ] -PROD_FP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fp16arith.c", -] - -TABLE_SRCS = [ - "XNNPACK/src/tables/exp2-k-over-64.c", - "XNNPACK/src/tables/exp2-k-over-2048.c", - "XNNPACK/src/tables/exp2minus-k-over-4.c", - "XNNPACK/src/tables/exp2minus-k-over-8.c", - "XNNPACK/src/tables/exp2minus-k-over-16.c", - "XNNPACK/src/tables/exp2minus-k-over-32.c", - "XNNPACK/src/tables/exp2minus-k-over-64.c", - "XNNPACK/src/tables/exp2minus-k-over-2048.c", - "XNNPACK/src/tables/vlog.c", -] - -PROD_NEON_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon.c", -] - -PROD_AVXVNNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avxvnni.c", -] - -PROD_NEONFP16ARITH_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith.c", -] - -PROD_SSE_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse.c", +PROD_AVX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx.c", ] -PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neon-aarch64.c", - "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", +PROD_AVX512SKX_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512skx.c", ] PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondotfp16-aarch64.c", ] -PROD_NEONFMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfma.c", +PROD_FP16ARITH_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fp16arith.c", ] PROD_FMA_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/fma.c", -] - -PROD_SSE2_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse2.c", -] - -PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vnnigfni.c", -] - -PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neonfp16arith-aarch64.c", -] - -PROD_AVX2_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx2.c", ] OPERATOR_SRCS = [ @@ -595,26 +434,176 @@ OPERATOR_SRCS = [ "XNNPACK/src/operators/unpooling-nhwc.c", ] -PROD_AVX512VBMI_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/avx512vbmi.c", +PROD_NEONI8MM_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neoni8mm.c", ] -PROD_NEONDOT_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/neondot.c", +PROD_AVX512F_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512f.c", +] + +JIT_SRCS = [ +] + +PROD_F16C_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/f16c.c", +] + +PROD_NEON_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon.c", +] + +PROD_SCALAR_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/scalar.c", ] PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondot-aarch64.c", ] -PROD_SSE41_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/sse41.c", +PROD_FMA3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/fma3.c", ] -PROD_ARMSIMD32_MICROKERNEL_SRCS = [ - "XNNPACK/src/amalgam/gen/armsimd32.c", +LOGGING_SRCS = [ + "XNNPACK/src/enums/allocation-type.c", + "XNNPACK/src/enums/datatype-strings.c", + "XNNPACK/src/enums/microkernel-type.c", + "XNNPACK/src/enums/node-type.c", + "XNNPACK/src/enums/operator-type.c", + "XNNPACK/src/log.c", +] + +PROD_NEONFMA_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neonfma.c", +] + +PROD_AVX2_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx2.c", +] + +PROD_AVX512VBMI_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/avx512vbmi.c", +] + +PROD_RVV_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/rvv.c", ] PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS = [ "XNNPACK/src/amalgam/gen/neondotfp16arith.c", ] + +PROD_XOP_MICROKERNEL_SRCS = [ +] + +AARCH32_ASM_MICROKERNEL_SRCS = [ + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x1.S", + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x2.S", + "XNNPACK/src/cs16-bfly4/cs16-bfly4-samples1-asm-aarch32-neon-x4.S", + "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x1.S", + "XNNPACK/src/cs16-fftr/cs16-fftr-asm-aarch32-neon-x4.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-asm-aarch32-vfp-ld64.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x4-minmax-asm-aarch32-vfp-ld64.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", + "XNNPACK/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch32-neon-ld64.S", + "XNNPACK/src/f32-igemm/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a55.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-1x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a7.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a53.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75-prfm.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-cortex-a75.S", + "XNNPACK/src/f32-igemm/gen/f32-igemm-4x8-minmax-asm-aarch32-neon-ld64.S", + "XNNPACK/src/qd8-f16-qc8w-gemm/gen/qd8-f16-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f16-qc8w-igemm/gen/qd8-f16-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondotfp16arith-cortex-a55.S", + "XNNPACK/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-4x8c4-minmax-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p8c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-dwconv/qs8-qc8w-dwconv-3p16c-minmax-fp32-asm-aarch32-neonv8-mla8-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-1x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a35.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-cortex-a53.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8-minmax-fp32-asm-aarch32-neonv8-mlal-lane-ld64.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-cortex-a55.S", + "XNNPACK/src/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x8c4-minmax-fp32-asm-aarch32-neondot-ld64.S", + "XNNPACK/src/qs16-qs8-vcvt/qs16-qs8-vcvt-asm-aarch32-neon-u16.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qu8-gemm/gen/qu8-gemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-1x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a7.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-cortex-a53.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64-prfm.S", + "XNNPACK/src/qu8-igemm/gen/qu8-igemm-4x8-minmax-rndnu-asm-aarch32-neon-mlal-lane-ld64.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-arm-x1.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x1.S", + "XNNPACK/src/u32-filterbank-accumulate/u32-filterbank-accumulate-asm-aarch32-neon-x2.S", +] + +PROD_SSSE3_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/ssse3.c", +] + +TABLE_SRCS = [ + "XNNPACK/src/tables/exp2-k-over-64.c", + "XNNPACK/src/tables/exp2-k-over-2048.c", + "XNNPACK/src/tables/exp2minus-k-over-4.c", + "XNNPACK/src/tables/exp2minus-k-over-8.c", + "XNNPACK/src/tables/exp2minus-k-over-16.c", + "XNNPACK/src/tables/exp2minus-k-over-32.c", + "XNNPACK/src/tables/exp2minus-k-over-64.c", + "XNNPACK/src/tables/exp2minus-k-over-2048.c", + "XNNPACK/src/tables/vlog.c", +] + +PROD_NEON_AARCH64_MICROKERNEL_SRCS = [ + "XNNPACK/src/amalgam/gen/neon-aarch64.c", + "XNNPACK/src/amalgam/gen/neonfma-aarch64.c", +] diff --git a/third_party/xnnpack_wrapper_defs.bzl b/third_party/xnnpack_wrapper_defs.bzl index b92ebb88d74efc..b05cdcd5cdec15 100644 --- a/third_party/xnnpack_wrapper_defs.bzl +++ b/third_party/xnnpack_wrapper_defs.bzl @@ -7,7 +7,6 @@ PROD_SCALAR_MICROKERNEL_SRCS = [ ] PROD_FMA_MICROKERNEL_SRCS = [ - "xnnpack_wrappers/amalgam/gen/fma.c", ] PROD_ARMSIMD32_MICROKERNEL_SRCS = [ @@ -92,7 +91,6 @@ PROD_F16C_MICROKERNEL_SRCS = [ ] PROD_XOP_MICROKERNEL_SRCS = [ - "xnnpack_wrappers/amalgam/gen/xop.c", ] PROD_FMA3_MICROKERNEL_SRCS = [ @@ -447,28 +445,16 @@ AARCH64_ASM_MICROKERNEL_SRCS = [ "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld64.S", "xnnpack_wrappers/qs8-qc8w-igemm/gen/qs8-qc8w-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-gemm/gen/qu8-gemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x8c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a53.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-cortex-a75.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64-prfm.S", "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16-minmax-rndnu-asm-aarch64-neon-mlal-lane-ld64.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-fp32-asm-aarch64-neondot-ld128.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-cortex-a55.S", - "xnnpack_wrappers/qu8-igemm/gen/qu8-igemm-4x16c4-minmax-rndnu-asm-aarch64-neondot-ld128.S", ] diff --git a/third_party/xpu.txt b/third_party/xpu.txt index c34e261f0e73c4..69606f14a7af3a 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -7eb52196954e5eaf32f791252506b0640f16f200 +7e3d00acea9f0d3728048a5b2743de20d55c64ba diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 967f63ae2c8f15..60a1be73fbb2d1 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -207,7 +207,7 @@ def remove_hcc(line: str) -> str: ignores=ignores, extra_files=[ "torch/_inductor/codegen/cpp_wrapper_cpu.py", - "torch/_inductor/codegen/cpp_wrapper_cuda.py", + "torch/_inductor/codegen/cpp_wrapper_gpu.py", "torch/_inductor/codegen/wrapper.py", ], out_of_place_only=args.out_of_place_only, diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 96c55c666b678c..9df4d965d9f788 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -278,6 +278,7 @@ - name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor theta: affine_grid_generator_backward_symint(grad, size, align_corners) + result: auto_linear - name: alias(Tensor(a) self) -> Tensor(a) self: grad @@ -2825,9 +2826,14 @@ cpu_nested_shape_example: non_differentiable - name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor - self: at::_nested_from_padded(grad, self._nested_tensor_size()) + self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())" padding: non_differentiable +- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef(padded.sym_sizes())) + offsets: non_differentiable + dummy: non_differentiable + - name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) self: grad.values() nested_size: non_differentiable @@ -2871,7 +2877,7 @@ output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) -- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 7b908425cf8660..fa6c578dea04a2 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -197,6 +197,7 @@ "nanmean", "nansum", "transpose", + "transpose_copy", "permute", "squeeze", "unsqueeze", @@ -259,6 +260,7 @@ "log1p", "log2", "logaddexp", + "logsumexp", "logcumsumexp", "reciprocal", "tan", diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index fab57c25ac9c2d..16c3b9e5efd6a6 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -999,9 +999,6 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg auto opt_memory_format = std::get<4>(parsed); auto& self_ = THPVariable_Unpack(self); torch::utils::maybe_initialize_device(device); - if (device && device->is_privateuseone()) { - at::globalContext().lazyInitPrivateUse1(); - } if (!device && !scalarType && !copy && !opt_memory_format.has_value()) { Py_INCREF(self); return self; @@ -1081,9 +1078,6 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa device = at::Device(device_type); } torch::utils::maybe_initialize_device(device); - if (device.is_privateuseone()) { - at::globalContext().lazyInitPrivateUse1(); - } return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format)); END_HANDLE_TH_ERRORS } diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index a0307c3b502870..37a9534860f189 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -20,11 +20,13 @@ Traceback, ) from tools.flight_recorder.components.utils import ( + align_trace_from_beginning, check_no_missing_dump_files, check_size_alltoall, - check_trace_from_beginning, check_version, find_coalesced_group, + format_frames, + get_version_detail, just_print_entries, match_coalesced_groups, match_one_event, @@ -48,7 +50,13 @@ def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc] def build_groups_memberships( pg_config: Any, -) -> Tuple[List[Group], Dict[Any, Group], List[Membership], Dict[str, Set[Any]]]: +) -> Tuple[ + List[Group], + Dict[Any, Group], + List[Membership], + Dict[str, Set[Any]], + Dict[Tuple[str, int], str], +]: """ pg_config: { global_rank: { @@ -70,6 +78,7 @@ def build_groups_memberships( `_groups`: a dict that is indexed by pg_guid with Group namedtuple as value. `memberships`: a membership table where each row is a Membership namedtuple. `_memberships`: a dict that is indexed by pg_guid with set of ranks (int) as value. + `_pg_guids`: a dict that is indexed by (pg_uid, global_rank) with pg_guid as value. """ # flat lists for return groups = [] @@ -78,10 +87,16 @@ def build_groups_memberships( # dicts for faster cross-rank validation _groups = {} _memberships = {} + _pg_guids = {} for global_rank in pg_config: - for pg_guid in pg_config[global_rank]: - desc = pg_config[global_rank][pg_guid]["desc"] - ranks = ast.literal_eval(pg_config[global_rank][pg_guid]["ranks"]) + for pg_uid in pg_config[global_rank]: + desc = pg_config[global_rank][pg_uid]["desc"] + ranks = ast.literal_eval(pg_config[global_rank][pg_uid]["ranks"]) + # With the adoption of the split_group API, we can have multiple PGs with the same pg_guid (PG Name) + # So we need to add the hash of all its ranks within the PG as well. + # Also guid must be a string because `_process_group_name` returns a string. + pg_guid = pg_uid + str(hash(frozenset(ranks))) + _pg_guids[(pg_uid, global_rank)] = pg_guid if isinstance(ranks, str): # TODO Bug in FR data format? ranks is '[0, 1,...]' ranks = eval(ranks) @@ -96,18 +111,18 @@ def build_groups_memberships( # validation across ranks assert ( _groups[pg_guid].desc == desc - ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc}" + ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" assert _memberships[pg_guid] == set( ranks - ), f"mismatch in membership {_memberships[pg_guid]} vs {set(ranks)}" - return groups, _groups, memberships, _memberships + ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" + return groups, _groups, memberships, _memberships, _pg_guids def build_nccl_call( entry: Dict[Any, Any], id: int, collective_id: Any, - group_id: int, + group_id: str, global_rank: Any, ) -> NCCLCall: return NCCLCall( @@ -122,9 +137,11 @@ def build_nccl_call( def build_collectives( - all_entries: Dict[str, List[Dict[str, Any]]], + all_entries: Dict[int, List[Dict[str, Any]]], _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], + version: str, ) -> Tuple[List[Traceback], List[Collective], List[NCCLCall]]: """ groups, memberships are the non-flat dicts that are indexable @@ -154,6 +171,7 @@ def build_collectives( ] } """ + major_v, minor_v = get_version_detail(version) tracebacks: List[Traceback] = [] collectives: List[Collective] = [] @@ -185,14 +203,20 @@ def build_collectives( entries = all_entries[first_rank] pg_name, desc = entries[0]["process_group"] profiling_name = entries[0]["profiling_name"] + pg_name = _pg_guids[(pg_name, first_rank)] collective_seq_id = entries[0]["collective_seq_id"] + record_id = entries[0]["record_id"] + input_sizes = entries[0]["input_sizes"] + output_sizes = entries[0]["output_sizes"] + collective_state = entries[0]["state"] + collective_frames = format_frames(entries[0]["frames"]) expected_ranks = set(_memberships[pg_name]) candidate_ranks = {first_rank} candidate_idx = {} found_ranks = set() found_idx = {} - if find_coalesced_group(pg_name, entries): + if find_coalesced_group(pg_name, entries, _pg_guids, first_rank): expected_ranks.add(first_rank) done_ranks = set() all_coalesced_entries = {} @@ -200,13 +224,13 @@ def build_collectives( curr = expected_ranks.pop() done_ranks.add(curr) grp = ( - find_coalesced_group(pg_name, all_entries[curr]) # type: ignore[index] + find_coalesced_group(pg_name, all_entries[curr], _pg_guids, curr) # type: ignore[index] if curr in all_entries # type: ignore[comparison-overlap] else [] ) all_coalesced_entries[curr] = grp for index, entry in grp: - op = Op(entry, _memberships) + op = Op(entry, _memberships, pg_name) peer = None if op.type == "send": assert op._src_g == curr, (op._src_g, curr) @@ -222,6 +246,7 @@ def build_collectives( group_size=_groups[pg_name].size, groups=_groups, memberships=_memberships, + _pg_guids=_pg_guids, ) if match and mismatch[pg_name] == 0: @@ -244,16 +269,19 @@ def build_collectives( nccl_calls.extend(reversed(reversed_calls)) else: has_undecided_case = False - errors = Set() + errors = set() for o in expected_ranks.intersection(set(other_ranks)): for i, e in enumerate(all_entries[o]): # type: ignore[index] # step over ops from other PGs # only check match state when seq_id matches if ( - e["process_group"] == (pg_name, desc) + _pg_guids[(e["process_group"][0], o)] == pg_name + and e["process_group"][1] == desc and e["collective_seq_id"] == collective_seq_id ): - match_state = match_one_event(entries[0], e, _memberships) + match_state = match_one_event( + entries[0], e, _memberships, pg_name + ) if ( match_state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] @@ -265,15 +293,25 @@ def build_collectives( else: candidate_ranks.add(o) candidate_idx[o] = i - errors.add(match_state) + if match_state not in [ + MatchState.FULLY_MATCHED, + MatchState.UNDECIDED, + ]: + # Here we assume the current rank is not the source of the error. + # But it's possible that the current rank is the culprit, then users will + # see lots of normal ranks reported as culprit. + # TODO: we need to figure out a better way to handle the case mentioned above. + errors.add((o, match_state)) break # case one: not every rank join the collective or in the flight recorder. if (candidate_ranks | found_ranks) != expected_ranks: mismatch[pg_name] += 1 print( - f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name}", - f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)}", + f"Not all ranks joining collective for group {pg_name}:{desc} collective {profiling_name} ", + f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nCollective stack traces: \n{collective_frames}", ) elif len(candidate_ranks) == 1: # case two: alltoall or alltoall_base case. @@ -281,14 +319,22 @@ def build_collectives( alltoall_cases = [entries[0]] + [ all_entries[o][found_idx[o]] for o in found_ranks ] - check_result, input_numel, output_numel = check_size_alltoall( + fail_check, input_numel, output_numel = check_size_alltoall( alltoall_cases ) - if not check_result: + if major_v <= 2 and minor_v <= 3: + # We don't log the input/output sizes for alltoall before v2.4, + # so we don't consider the size mismatch as an error for now. + fail_check = False + if fail_check: + # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. mismatch[pg_name] += 1 print( - f"Input/output mismatch in the collective for group {pg_name}:{desc} collective {profiling_name}", - f"input_numel {input_numel} output_numel{output_numel}", + f"Input/output mismatch in the collective {record_id} ", + f"for group {pg_name}:{desc} collective {profiling_name} ", + f"input_numel {input_numel} output_numel {output_numel} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nCollective stack traces: \n{collective_frames}", ) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) @@ -306,11 +352,16 @@ def build_collectives( candidate_idx.clear() candidate_ranks.clear() # case four: mismatch cases due to not same type, size mismatch or state mismatch. - else: - error_msg = ", ".join(error.name for error in errors) + elif len(errors) > 0: + mismatch[pg_name] += 1 + error_msg = ", ".join( + f"Error rank {error[0]}, {str(error[1])}" for error in errors + ) print( - f"Collective errors for group {pg_name}:{desc} collective {profiling_name}", - f"Found errors: {error_msg}", + f"Collective {record_id} errors for group {pg_name}:{desc} collective {profiling_name} ", + f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", + f"\nFound errors: {error_msg}\n", + f"\nCollective stack traces: \n{collective_frames} ", ) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) @@ -361,32 +412,37 @@ def build_collectives( return tracebacks, collectives, nccl_calls -def build_db(details: Dict[str, Dict[str, Any]], args: argparse.Namespace) -> Database: +def build_db( + details: Dict[str, Dict[str, Any]], args: argparse.Namespace, version: str +) -> Database: # temporary state used for building database entries = {} pg_config = {} - version = {} + version_by_ranks = {} for dump in details.values(): rank = dump["rank"] entries[rank] = dump["entries"] - version[rank] = dump["version"] + version_by_ranks[rank] = dump["version"] pg_config[rank] = dump["pg_config"] - check_version(version) - check_trace_from_beginning(entries) + # Ensure version is consistent across all ranks. + check_version(version_by_ranks, version) + entries = align_trace_from_beginning(entries) # flattened database - groups, _groups, memberships, _memberships = build_groups_memberships(pg_config) + groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( + pg_config + ) print("built groups, memberships") check_no_missing_dump_files(entries, memberships) if args.just_print_entries: - just_print_entries(entries, _groups, _memberships) + just_print_entries(entries, _groups, _memberships, _pg_guids, args) sys.exit(0) tracebacks, collectives, nccl_calls = build_collectives( - entries, _groups, _memberships + entries, _groups, _memberships, _pg_guids, version ) print("built collectives, nccl_calls") if args.verbose: diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index d296bd981b811c..618aa40b55be58 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import argparse +from typing import Optional, Sequence class JobConfig: @@ -16,19 +17,47 @@ def __init__(self: "JobConfig"): self.parser = argparse.ArgumentParser( description="PyTorch Flight recorder analyzing script." ) - self.parser.add_argument( - "-d", "--dir", help="Directory with flight recorder dumps" + "trace_dir", + help="Directory containing one trace file per rank, named with _.", + ) + self.parser.add_argument( + "--selected-ranks", + default=None, + nargs="+", + type=int, + help="List of ranks we want to show traces for.", + ) + self.parser.add_argument( + "--pg-filters", + default=None, + nargs="+", + type=str, + help="List of filter strings", ) self.parser.add_argument("-o", "--output", default=None) self.parser.add_argument( "-p", "--prefix", - help="prefix to strip such that rank can be extracted", - default="rank_", + help=( + "Common filename prefix to strip such that rank can be extracted. " + "If not specified, will attempt to infer a common prefix." + ), + default=None, ) self.parser.add_argument("-j", "--just_print_entries", action="store_true") self.parser.add_argument("-v", "--verbose", action="store_true") - def parse_args(self: "JobConfig") -> argparse.Namespace: - return self.parser.parse_args() + def parse_args( + self: "JobConfig", args: Optional[Sequence[str]] + ) -> argparse.Namespace: + args = self.parser.parse_args(args) + if args.selected_ranks is not None: + assert ( + args.just_print_entries + ), "Not support selecting ranks without printing entries" + if args.pg_filters is not None: + assert ( + args.just_print_entries + ), "Not support selecting pg filters without printing entries" + return args diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index e8129cc5a215fc..451f14df37f205 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -7,15 +7,16 @@ import gc import os import pickle +import re import time -from typing import Any, Dict, List, Union +import typing +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple, Union def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]: basename = os.path.basename(filename) - assert ( - basename.find(prefix) == 0 - ), f"args.prefix ({prefix}) must match the beginning of each filename ({basename})" + rank = int(basename[len(prefix) :]) host_name = f"host_rank{rank}" @@ -35,15 +36,47 @@ def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any] } -def read_dir(prefix: str, folder: str) -> Dict[str, Dict[str, Any]]: +exp = re.compile(r"([\w\-\_]*?)(\d+)$") + + +def _determine_prefix(files: List[str]) -> str: + """If the user doesn't specify a prefix, but does pass a dir full of similarly-prefixed files, we should be able to + infer the common prefix most of the time. But if we can't confidently infer, just fall back to requring the user + to specify it + """ + possible_prefixes: typing.DefaultDict[str, Set[int]] = defaultdict(set) + for f in files: + m = exp.search(f) + if m: + p, r = m.groups() + possible_prefixes[p].add(int(r)) + if len(possible_prefixes) == 1: + prefix = next(iter(possible_prefixes)) + print(f"Inferred common prefix {prefix}") + return prefix + else: + raise ValueError( + "Unable to automatically determine the common prefix for the trace file names. " + "Please specify --prefix argument manually" + ) + + +def read_dir( + prefix: Optional[str], folder: str +) -> Tuple[Dict[str, Dict[str, Any]], str]: gc.disable() details = {} t0 = time.time() + version = "" for root, _, files in os.walk(folder): + if prefix is None: + prefix = _determine_prefix(files) for f in files: - ta = time.time() + if f.find(prefix) != 0: + continue details[f] = read_dump(prefix, os.path.join(root, f)) - tb = time.time() - # print(f"read file {f} in {tb - ta}s") + if not version: + version = str(details[f]["version"]) + tb = time.time() print(f"loaded {len(files)} files in {tb - t0}s") - return details + return details, version diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index e55c2370f30cd0..1f2b75a05eb735 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum +import math +from enum import auto, Enum from typing import ( # type: ignore[attr-defined] _eval_type, Any, @@ -12,6 +13,7 @@ Generic, List, NamedTuple, + Optional, Set, Tuple, Type, @@ -66,13 +68,13 @@ def from_type(cls, c: T) -> "TypeInfo": class Group(NamedTuple): - id: int + id: str desc: str size: int class Membership(NamedTuple): - group_id: Ref[Group] + group_id: str global_rank: int @@ -83,13 +85,13 @@ class Traceback(NamedTuple): class Collective(NamedTuple): id: int - group_id: Ref[Group] + group_id: str class NCCLCall(NamedTuple): id: int collective_id: Ref[Collective] - group_id: Ref[Group] + group_id: str global_rank: int # technically Ref[Process] once we have it traceback_id: Ref[Traceback] collective_type: str @@ -138,8 +140,7 @@ class Database(NamedTuple): "_reduce_scatter_base", "gather", "scatter", - "alltoall_base", - "alltoall", + "all_to_all", } P2P = { @@ -157,32 +158,25 @@ class MatchState(Enum): - SIZE_OR_SYNTAX_MISMATCH: There is a mismatch in input/output sizes or violation of collective syntax. - COLLECTIVE_STATE_MISMATCH: The states of the collective not same, such as one finished while another just started or scheduled. + - COLLECTIVE_DTYPE_MISMATCH: The data types of the collective input/output differ. - UNDECIDED: The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. """ - FULLY_MATCHED = 1 - COLLECTIVE_TYPE_MISMATCH = 2 - SIZE_OR_SYNTAX_MISMATCH = 3 - COLLECTIVE_STATE_MISMATCH = 4 - UNDECIDED = 5 - - -def check_size_evenly_broadcasting( - list1: List[Any], list2: List[Any], size: int -) -> bool: - ratio = None - for a, b in zip(list1, list2): - current_ratio = int(a) / int(b) - if current_ratio == 1: - continue - if current_ratio != size: - return False - elif ratio is None: - ratio = current_ratio - else: - return False - return True + FULLY_MATCHED = auto() + COLLECTIVE_TYPE_MISMATCH = auto() + SIZE_OR_SYNTAX_MISMATCH = auto() + COLLECTIVE_STATE_MISMATCH = auto() + COLLECTIVE_DTYPE_MISMATCH = auto() + UNDECIDED = auto() + + def __call__(self, culprit: Optional[str] = None) -> "MatchState": + # Make the enum instance callable to add culprit. + self.culprit = culprit + return self + + def __str__(self) -> str: + return f"Error type: {self.name}, Detail finding {self.culprit if self.culprit else ''}" class Op: @@ -194,7 +188,9 @@ class Op: nccl:recv 3<-0 """ - def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): + def __init__( + self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]], pg_name: str + ): profiling_name = event["profiling_name"] nccl, name = profiling_name.split(":") assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'" @@ -202,9 +198,7 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): type = parts[0] meta = parts[1] if len(parts) == 2 else None self.state = event["state"] - self.pg_name, _ = event["process_group"] - assert type in COLLECTIVES | P2P | { "coalesced" }, f"{type} is not a supported operation" @@ -217,10 +211,9 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): self._dst, self._src = int(d), int(s) else: self._src, self._dst = -1, -1 - pg_name, pg_desc = event["process_group"] + _, pg_desc = event["process_group"] self._init_global_src_dst(memberships[pg_name]) self.pg_size = len(memberships[pg_name]) - if type in P2P | COLLECTIVES: self.input_sizes = event["input_sizes"] self.output_sizes = event["output_sizes"] @@ -228,6 +221,8 @@ def __init__(self, event: Dict[Any, Any], memberships: Dict[str, Set[Any]]): self.input_sizes, self.output_sizes = None, None self.collective_seq_id = event["collective_seq_id"] self.p2p_seq_id = event["p2p_seq_id"] + self.input_dtypes = event["input_dtypes"] + self.output_dtypes = event["output_dtypes"] def _init_global_src_dst(self, pg_ranks: Set[Any]) -> None: pg_ranks = sorted(pg_ranks) @@ -246,10 +241,8 @@ def dst(self) -> int: def __repr__(self) -> str: if self.type in P2P: - return ( - f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes})" - ) - return f"{self.type}(input_sizes={self.input_sizes}, {self.state})" + return f"{self.type}(s={self._src_g} d={self._dst_g}, sz={self.input_sizes}, state={self.state})" + return f"{self.type}(input_sizes={self.input_sizes}, state={self.state})" def match(self, other: "Op") -> MatchState: # TODO: I think this can validly not match, @@ -282,34 +275,60 @@ def match(self, other: "Op") -> MatchState: ) elif self.type in COLLECTIVES: if self.type != other.type: - return MatchState.COLLECTIVE_TYPE_MISMATCH - if self.type in ["alltoall", "alltoall_base"]: + return MatchState.COLLECTIVE_TYPE_MISMATCH( + f"Type '{self.type}' and '{other.type}' do not match" + ) + if self.state != other.state: + # MatchState() + return MatchState.COLLECTIVE_STATE_MISMATCH( + f"States '{self.state}' '{other.state}' do not match" + ) + if ( + other.input_dtypes != other.output_dtypes + or self.input_dtypes != other.input_dtypes + or self.output_dtypes != other.output_dtypes + ): + return MatchState.COLLECTIVE_DTYPE_MISMATCH( + f"Dtypes '{self.input_dtypes}/{other.input_dtypes}' '{self.output_dtypes}/{other.output_dtypes}' do not match" + ) + if self.type == "all_to_all": return MatchState.UNDECIDED if self.type != "scatter" and self.input_sizes != other.input_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input sizes '{self.input_sizes}' '{other.input_sizes}' do not match" + ) if self.type != "gather" and self.output_sizes != other.output_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Output sizes '{self.output_sizes}' '{other.output_sizes}' do not match" + ) if self.type == "all_reduce" and self.input_sizes != other.output_sizes: - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input sizes '{self.input_sizes}' do not match output sizes '{other.output_sizes}'" + ) # TODO: need to consider uneven sharding for all-gather. # TODO: need to consider all_gather_into_tensor_coalesced (coalesced related) if self.type in [ "all_gather", "all_gather_base", - ] and not check_size_evenly_broadcasting( - other.output_sizes, self.input_sizes, self.pg_size + ] and not ( + math.prod(other.output_sizes[0]) + == math.prod(self.input_sizes[0]) * self.pg_size ): - return MatchState.SIZE_OR_SYNTAX_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input numel '{math.prod(other.input_sizes[0])} * pg size {self.pg_size}' " + f"do not match output numel '{math.prod(other.output_sizes[0])}'", + ) if self.type in [ "reduce_scatter", "_reduce_scatter_base", - ] and not check_size_evenly_broadcasting( - other.input_sizes, self.output_sizes, self.pg_size + ] and not ( + math.prod(other.input_sizes[0]) + == math.prod(self.output_sizes[0]) * self.pg_size ): - return MatchState.SIZE_OR_SYNTAX_MISMATCH - # TODO: need to add more checks for gather and scatter. - if self.state != other.state: - return MatchState.COLLECTIVE_STATE_MISMATCH + return MatchState.SIZE_OR_SYNTAX_MISMATCH( + f"Input numel '{math.prod(other.input_sizes[0])}' do not match output numel " + f"'{math.prod(other.output_sizes[0])} * pg size {self.pg_size}'", + ) elif self.type == "coalesced": return ( MatchState.FULLY_MATCHED diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index ef0e9a9f1380f5..87e3fc6a1c9664 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import argparse import math -from typing import Any, Dict, List, Set, Tuple # type: ignore[attr-defined] +from typing import Any, Dict, List, Set, Tuple from tools.flight_recorder.components.types import ( Group, @@ -22,13 +23,28 @@ print("tabulate is not installed. Proceeding without it.") +def format_frame(frame: Dict[str, str]) -> str: + name = frame["name"] + filename = frame["filename"] + line = frame["line"] + return f"{name} at {filename}:{line}" + + +def format_frames(frames: List[Dict[str, str]]) -> str: + formatted_frames = [] + for frame in frames: + formatted_frames.append(format_frame(frame)) + return "\n".join(formatted_frames) + + def match_one_event( event_a: Dict[Any, Any], event_b: Dict[Any, Any], memberships: Dict[str, Set[Any]], + pg_name: str, ) -> MatchState: - op_a = Op(event_a, memberships) - op_b = Op(event_b, memberships) + op_a = Op(event_a, memberships, pg_name) + op_b = Op(event_b, memberships, pg_name) return op_a.match(op_b) @@ -37,6 +53,7 @@ def match_coalesced_groups( group_size: int, groups: Dict[str, Group], memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], ) -> bool: """ all_rank_events: { @@ -62,13 +79,22 @@ def match_coalesced_groups( rank1 [recv:0 (1000B), recv:0 (100B)] —> not okay """ all_ops = { - rank: [Op(e, memberships) for i, e in all_rank_events[rank]] + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] for rank in all_rank_events } - def visualize_ops(match: bool) -> None: + def visualize_ops( + match: bool, + _pg_guids: Dict[Tuple[str, int], str], + ) -> None: all_ops = { - rank: [Op(e, memberships) for i, e in all_rank_events[rank]] + rank: [ + Op(e, memberships, _pg_guids[(e["process_group"][0], rank)]) + for i, e in all_rank_events[rank] + ] for rank in all_rank_events } @@ -80,8 +106,14 @@ def visualize_ops(match: bool) -> None: progress = False for r in all_ops: if len(all_ops[r]) > i: - _, event = all_rank_events[r][i] - row.append(Op(event, memberships)) + rank, event = all_rank_events[r][i] + row.append( + Op( + event, + memberships, + _pg_guids[(event["process_group"][0], rank)], + ) + ) progress = True else: row.append(None) # type: ignore[arg-type] @@ -130,10 +162,10 @@ def visualize_ops(match: bool) -> None: my_ops.pop(0) peer_ops.pop(match_idx) else: - visualize_ops(False) + visualize_ops(False, _pg_guids) return False - visualize_ops(True) + visualize_ops(True, _pg_guids) return True @@ -147,21 +179,27 @@ def check_size_alltoall(alltoall_cases: List[Dict[str, Any]]) -> Tuple[bool, int def find_coalesced_group( - pg_name: str, entries: List[Dict[str, Any]] + pg_name: str, + entries: List[Dict[str, Any]], + _pg_guids: Dict[Tuple[str, int], str], + rank: int, ) -> List[Tuple[int, Dict[str, Any]]]: """Given a list of entries, if the collective_seq_id of the first entry matches that of subsequent ones, build an return a list of entries terminating in a 'coalesced' op entry all sharing a collective_seq_id - TODO: handle p2p_seq_id v/s collective_seq_id separately here. """ found = [] collective_seq_id = None for i, e in enumerate(entries): - if e["process_group"][0] != pg_name: + if _pg_guids[(e["process_group"][0], rank)] != pg_name: continue elif collective_seq_id is None: - collective_seq_id = e["collective_seq_id"] + collective_seq_id = ( + e["p2p_seq_id"] if e["is_p2p"] else e["collective_seq_id"] + ) + found.append((i, e)) + elif not e["is_p2p"] and e["collective_seq_id"] == collective_seq_id: found.append((i, e)) - elif e["collective_seq_id"] == collective_seq_id: + elif e["is_p2p"] and e["p2p_seq_id"] == collective_seq_id: found.append((i, e)) else: break @@ -176,20 +214,35 @@ def just_print_entries( all_entries: Dict[int, List[Dict[str, Any]]], _groups: Dict[str, Group], _memberships: Dict[str, Set[Any]], + _pg_guids: Dict[Tuple[str, int], str], + args: argparse.Namespace, ) -> None: rows = [] ranks = sorted(all_entries.keys()) - headers = [f"Rank {rank}" for rank in ranks] + headers = [ + f"Rank {rank}" + for rank in ranks + if args.selected_ranks is None or rank in args.selected_ranks + ] progress = True while progress: progress = False row = [] for rank in ranks: + if args.selected_ranks is not None and rank not in args.selected_ranks: + continue if len(all_entries[rank]) == 0: row.append("") else: entry = all_entries[rank].pop(0) - row.append(str(Op(entry, _memberships))) + pg_name = _pg_guids[(entry["process_group"][0], rank)] + if ( + args.pg_filters is None + or entry["process_group"][1] in args.pg_filters + ): + row.append(str(Op(entry, _memberships, pg_name))) + else: + row.append("") progress = True if progress: rows.append(row) @@ -198,34 +251,69 @@ def just_print_entries( def check_no_missing_dump_files( - entries: Dict[str, Any], memberships: List[Membership] + entries: Dict[int, Any], memberships: List[Membership] ) -> None: all_ranks = set() for membership in memberships: - all_ranks.add(str(membership.global_rank)) - dumps_ranks = set(entries.keys()) + all_ranks.add(int(membership.global_rank)) + dumps_ranks = {int(key) for key in entries.keys()} assert ( dumps_ranks == all_ranks ), f"Missing dump files from ranks {all_ranks - dumps_ranks}" -def check_version(versions: Dict[str, Any]) -> None: - for rank, version in versions.items(): # noqa: PERF102 - major, minor = map(int, version.split(".")) - # assert major == 2, f"Rank {rank} unsupported version {version}" - # assert minor >= 0, f"Rank {rank} unsupported version {version}" +def check_version(version_by_ranks: Dict[str, str], version: str) -> None: + for rank, v in version_by_ranks.items(): + assert ( + v == version + ), f"Rank {rank} has different version {v} from the given version {version}" + +def get_version_detail(version: str) -> Tuple[int, int]: + version = version.split(".") + assert len(version) == 2, f"Invalid version {version}" + major, minor = map(int, version) + return major, minor -# TODO: We need to revisit this function to see if we still need it. -def check_trace_from_beginning(entries: Dict[str, Any]) -> bool: + +def align_trace_from_beginning( + entries: Dict[int, List[Dict[str, Any]]] +) -> Dict[int, List[Dict[str, Any]]]: + """ + Align the trace entries by record ID for entries. + This function takes a dictionary of rank names to lists of trace entries as input. + Each trace entry is a dictionary containing information about a collective operation, + including its unique identifier (`record_id` is monotonically increasing as we write into the ring buffer). + The function finds the largest starting point across all ranks by taking the maximum + `record_id` value of the first entry in each rank. Finally, it filters out any + entries with `record_id` values less than the maximum starting point. + The function returns the updated dictionary of sorted and filtered trace entries. + + Args: + entries (Dict[str, List[Dict[str, Any]]]): A dictionary of rank names to lists of trace entries. + + Returns: + entries (Dict[str, List[Dict[str, Any]]]): Entries sorted by record ID and filtered by the maximum starting point. + """ + + maximum_starting_record_id = 0 for rank in entries: + # Although this is a ring buffer, we already sort the entries by `record_id` when dumping, we just + # need to find the largest starting point. For example, if the buffer has the following entries: + # Rank 0: [0, 1, 2, 3, 4, 5, 6] + # Rank 1: [1, 2, 3, 4, 5, 6, 7] + # Rank 2: [2, 3, 4, 5, 6, 7, 8] + # Rank 3: [0, 1, 2, 3, 4, 5, None] + # Then we should start from collective 2 not 0 because any collective before, + # we don't have complete records from all ranks so we need to ignore them. first_record_id = entries[rank][0]["record_id"] - # TODO add more sequence information such that analysis can proceed even without complete buffer + maximum_starting_record_id = max(maximum_starting_record_id, first_record_id) - # assert first_record_id == 0, f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}." - if first_record_id != 0: - print( - f"Rank {rank} trace does not start at time 0 (first record is {first_record_id}." - ) - return False - return True + for rank in entries: + entries[rank] = [ + entry + for entry in entries[rank] + if entry["record_id"] >= maximum_starting_record_id + ] + + return entries diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index bf58aee2a0f103..c2b8b81a9fa27e 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -28,8 +28,8 @@ - This script is versioned so that we can ensure our future changes to flight recorder are backwards compatible. """ -import argparse import pickle +from typing import Optional, Sequence from tools.flight_recorder.components.builder import build_db from tools.flight_recorder.components.config_manager import JobConfig @@ -37,14 +37,15 @@ from tools.flight_recorder.components.types import types -def main(args: argparse.Namespace) -> None: - details = read_dir(args.prefix, args.dir) - db = build_db(details, args) +def main(args: Optional[Sequence[str]] = None) -> None: + config = JobConfig() + args = config.parse_args(args) + details, version = read_dir(args.prefix, args.trace_dir) + db = build_db(details, args, version) if args.output: with open(args.output, "wb") as f: pickle.dump((types, db), f) if __name__ == "__main__": - config = JobConfig() - main(config.parse_args()) + main() diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index e4e15bcbc4f6da..bd2fcee5e51c4a 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -781,6 +781,9 @@ def gen_pyi( "_is_functional_tensor": [ "def _is_functional_tensor(t: Tensor) -> _bool: ..." ], + "_is_functional_tensor_base": [ + "def _is_functional_tensor_base(t: Tensor) -> _bool: ..." + ], "_from_functional_tensor": [ "def _from_functional_tensor(t: Tensor) -> Tensor: ..." ], diff --git a/tools/pytorch.version b/tools/pytorch.version deleted file mode 100644 index 3488ccfc56b1df..00000000000000 --- a/tools/pytorch.version +++ /dev/null @@ -1,31 +0,0 @@ -{ - global: - _TH*; - __TH*; - TH*; - *THP*; - *THCP*; - PyInit*; - init*; - state; - _ZGVZN2at*; - _ZN2at*; - _ZNK2at*Type*; - _ZNK2at*Tensor*; - _ZNK2at*Storage*; - _ZNK2at*Scalar*; - _ZNK2at*CUDA*; - *2at7Context*; - _ZTIN2at*; - _ZTIZN2at*; - _ZTSN2at*; - _ZTSPN2at*; - _ZTSZN2at*; - _ZTVN2at*; - _ZZN2at*; - _Z*torch*; - _Z*Tensor*; - _Z*tensor*; - local: - *; - }; diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index d50d9d54739979..4b605fe597505a 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -204,7 +204,6 @@ def generate( "UBSAN_FLAGS", "BLAS", "WITH_BLAS", - "BUILDING_WITH_TORCH_LIBS", "CUDA_HOST_COMPILER", "CUDA_NVCC_EXECUTABLE", "CUDA_SEPARABLE_COMPILATION", diff --git a/tools/test/test_executorch_types.py b/tools/test/test_executorch_types.py index c00a02cd500e1b..dedb19e21f3e6d 100644 --- a/tools/test/test_executorch_types.py +++ b/tools/test/test_executorch_types.py @@ -3,6 +3,7 @@ from torchgen import local from torchgen.api.types import ( BaseCType, + boolT, ConstRefCType, CType, longT, @@ -64,6 +65,10 @@ def test_argumenttype_type(self) -> None: "int[]? dims", NamedCType("dims", OptionalCType(ArrayRefCType(BaseCType(longT)))), ), + ( + "bool[3] output_mask", + NamedCType("output_mask", ArrayRefCType(BaseCType(boolT))), + ), ] for d in data: self._test_argumenttype_type(*d) diff --git a/tools/testing/target_determination/heuristics/utils.py b/tools/testing/target_determination/heuristics/utils.py index 17259756533d5c..7c408277559b1e 100644 --- a/tools/testing/target_determination/heuristics/utils.py +++ b/tools/testing/target_determination/heuristics/utils.py @@ -116,7 +116,7 @@ def get_issue_or_pr_body(number: int) -> str: # Despite the 'issues' in the link, this also works for PRs url = f"https://api.github.com/repos/pytorch/pytorch/issues/{number}" with urlopen(Request(url, headers=headers)) as conn: - body: str = json.loads(conn.read().decode())["body"] + body: str = json.loads(conn.read().decode())["body"] or "" return body diff --git a/tools/testing/update_slow_tests.py b/tools/testing/update_slow_tests.py index e88e85f5c5a090..d10daf6a8386b6 100644 --- a/tools/testing/update_slow_tests.py +++ b/tools/testing/update_slow_tests.py @@ -3,7 +3,7 @@ import subprocess import time from pathlib import Path -from typing import Any, cast, Dict, Optional, Tuple +from typing import Any, cast, Dict, List, Optional, Tuple import requests import rockset # type: ignore[import] @@ -93,7 +93,7 @@ def git_api( - url: str, params: Dict[str, str], type: str = "get", token: str = UPDATEBOT_TOKEN + url: str, params: Dict[str, Any], type: str = "get", token: str = UPDATEBOT_TOKEN ) -> Any: headers = { "Accept": "application/vnd.github.v3+json", @@ -147,6 +147,15 @@ def make_comment(source_repo: str, pr_number: int, msg: str) -> None: ) +def add_labels(source_repo: str, pr_number: int, labels: List[str]) -> None: + params = {"labels": labels} + git_api( + f"/repos/{source_repo}/issues/{pr_number}/labels", + params, + type="post", + ) + + def search_for_open_pr( source_repo: str, search_string: str ) -> Optional[Tuple[int, str]]: @@ -204,5 +213,6 @@ def search_for_open_pr( # no existing pr, so make a new one and approve it pr_num = make_pr("pytorch/pytorch", params) time.sleep(5) + add_labels("pytorch/pytorch", pr_num, ["ciflow/slow", "ci-no-td"]) approve_pr("pytorch/pytorch", pr_num) make_comment("pytorch/pytorch", pr_num, "@pytorchbot merge") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 9a91b26d54cfb4..b8dfb8b706ba1a 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -340,11 +340,17 @@ endif() # in case of the split build we need to add compile definitions if(BUILD_LIBTORCHLESS) + + if(USE_UCC) + target_link_libraries(torch_python PRIVATE __caffe2_ucc) + target_compile_definitions(torch_python PRIVATE USE_UCC) + endif() + if(USE_UCC AND USE_C10D_UCC) target_compile_definitions(torch_python PRIVATE USE_C10D_UCC) endif() - if(USE_UCC AND USE_C10D_NCCL) + if(USE_NCCL AND USE_C10D_NCCL) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1638d3891e4084..cb61d8dbb70170 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1784,6 +1784,7 @@ def _mtia_getCurrentStream(device: _int) -> Stream: ... def _mtia_setCurrentStream(stream: Stream) -> None: ... def _mtia_getDefaultStream(device: _int) -> Stream: ... def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ... +def _mtia_getDeviceCapability(device: _int) -> Tuple[_int, _int]: ... # Defined in torch/csrc/mps/Module.cpp @@ -1831,6 +1832,7 @@ def _cuda_cudaHostAllocator() -> _int: ... def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ... +def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ... @@ -2107,6 +2109,9 @@ def _xpu_getCurrentStream(device: _int) -> Tuple: ... def _xpu_getCurrentRawStream(device: _int) -> _int: ... def _xpu_synchronize(device: _int) -> None: ... def _xpu_emptyCache() -> None: ... +def _xpu_memoryStats(device: _int) -> Dict[str, Any]: ... +def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... +def _xpu_resetPeakMemoryStats(device: _int) -> None: ... class _XpuDeviceProperties: name: str diff --git a/torch/_C/_aoti.pyi b/torch/_C/_aoti.pyi index a5e782fe62123e..4e9f5e7c8671b1 100644 --- a/torch/_C/_aoti.pyi +++ b/torch/_C/_aoti.pyi @@ -18,3 +18,6 @@ def alloc_tensor_by_stealing_from_void_ptr( class AOTIModelContainerRunnerCpu: ... class AOTIModelContainerRunnerCuda: ... + +# Defined in torch/csrc/inductor/aoti_package/pybind.cpp +class AOTIModelPackageLoader: ... diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 353ad12f5b7d3e..6593222a119f4d 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -2,10 +2,11 @@ from torch.types import _bool, _int # Defined in torch/csrc/cpu/Module.cpp -def _is_cpu_support_avx2() -> _bool: ... -def _is_cpu_support_avx512() -> _bool: ... -def _is_cpu_support_avx512_vnni() -> _bool: ... -def _is_cpu_support_amx_tile() -> _bool: ... +def _is_avx2_supported() -> _bool: ... +def _is_avx512_supported() -> _bool: ... +def _is_avx512_vnni_supported() -> _bool: ... +def _is_avx512_bf16_supported() -> _bool: ... +def _is_amx_tile_supported() -> _bool: ... def _init_amx() -> _bool: ... def _L1d_cache_size() -> _int: ... def _L2_cache_size() -> _int: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6033d969925972..f89f0b50c8582e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -296,15 +296,6 @@ class Backend: def _set_default_timeout(self, timeout: timedelta) -> None: ... class ProcessGroup: - class Options: - def __init__(self, backend: str, timeout: timedelta = ...) -> None: ... - @property - def backend(self) -> str: ... - @property - def _timeout(self) -> timedelta: ... - @_timeout.setter - def _timeout(self, val: timedelta) -> None: ... - class BackendType(Enum): UNDEFINED = ... GLOO = ... @@ -319,7 +310,6 @@ class ProcessGroup: store: Store, rank: int, size: int, - options: Options, ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... @@ -509,6 +499,7 @@ class ProcessGroup: @property def _device_types(self) -> list[torch.device]: ... def _get_backend(self, device: torch.device) -> Backend: ... + def _set_default_backend(self, backend_type: BackendType) -> None: ... def _register_backend( self, device: torch.device, @@ -533,7 +524,7 @@ class ProcessGroup: class ProcessGroupGloo(Backend): class Device: ... - class Options(ProcessGroup.Options): + class Options(Backend.Options): devices: list[ProcessGroupGloo.Device] threads: int @@ -563,7 +554,7 @@ class ProcessGroupNCCL(Backend): min_ctas: int max_ctas: int - class Options(ProcessGroup.Options): + class Options(Backend.Options): config: ProcessGroupNCCL.NCCLConfig is_high_priority_stream: bool split_from: ProcessGroupNCCL diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 499a66ed4643d4..548fc1f59e0ff1 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import types -from typing import NewType +from typing import NewType, Tuple from torch._dynamo.types import DynamoCallback, DynamoGuardHook @@ -8,11 +8,17 @@ from torch._dynamo.types import DynamoCallback, DynamoGuardHook # exposes the same interface. _PyInterpreterFrame = NewType("_PyInterpreterFrame", types.FrameType) +# For typechecking +SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object) +# Flag returned by Dynamo tracer to indicate to Dynamo eval frame that we should skip frames recursively. +skip_code_recursive_flag: SkipCodeRecursiveFlag + def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def reset_code(code: types.CodeType) -> None: ... def unsupported(obj1: object, obj2: object) -> object: ... def skip_code(code: types.CodeType) -> None: ... def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... +def set_context_frame(context: Tuple[int, int, int]) -> None: ... class _CacheEntry: def check_fn(self, *args, **kwargs): ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index d6aab0d547f727..918d913068e6df 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -67,7 +67,7 @@ class GuardManager: ) -> None: ... def add_global_state_guard(self, verbose_code_parts: list[str]) -> None: ... def add_torch_function_mode_stack_guard( - self, initial_stack, ignored_types, verbose_code_parts: list[str] + self, initial_stack, verbose_code_parts: list[str] ) -> None: ... class RootGuardManager(GuardManager): diff --git a/torch/_C/_instruction_counter.pyi b/torch/_C/_instruction_counter.pyi new file mode 100644 index 00000000000000..4e3c27567eb228 --- /dev/null +++ b/torch/_C/_instruction_counter.pyi @@ -0,0 +1,4 @@ +# Defined in torch/csrc/instruction_counter/Module.cpp + +def start() -> int: ... +def end(id: int) -> int: ... diff --git a/torch/_VF.py b/torch/_VF.py index 13fe080e6f41e5..94166b51f17865 100644 --- a/torch/_VF.py +++ b/torch/_VF.py @@ -20,12 +20,12 @@ class VFModule(types.ModuleType): vf: types.ModuleType - def __init__(self, name): + def __init__(self, name: str): super().__init__(name) self.vf = torch._C._VariableFunctions - def __getattr__(self, attr): - return getattr(self.vf, attr) + def __getattr__(self, name: str) -> object: + return getattr(self.vf, name) sys.modules[__name__] = VFModule(__name__) diff --git a/torch/__init__.py b/torch/__init__.py index 0a0776ce459f4f..64bfab708388a1 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2180,10 +2180,6 @@ def __init__(self, mode, options, dynamic): self.apply_mode(mode) self.apply_options(options) - # Stash the compiler_fn to be used for backend match guard. - from torch._inductor.compile_fx import compile_fx - - self.compiler_fn = compile_fx if self.config.get("triton.cudagraphs", False): os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" # FIXME: CUDA Graph does not work well with CUPTI teardown. @@ -2382,8 +2378,9 @@ def compile( There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints to debug. - - "max-autotune" is a mode that leverages Triton based matrix multiplications and convolutions - It enables CUDA graphs by default. + - "max-autotune" is a mode that leverages Triton or template based matrix multiplications + on supported devices and Triton based convolutions on GPU. + It enables CUDA graphs by default on GPU. - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs diff --git a/torch/_awaits/__init__.py b/torch/_awaits/__init__.py index 3770557d6f263d..b08067bdcf45a1 100644 --- a/torch/_awaits/__init__.py +++ b/torch/_awaits/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import cast, Callable, Generic, Type, TypeVar +from typing import Generic, TypeVar import torch diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 93bbec04a425be..a22289b75c4012 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -1,15 +1,26 @@ # mypy: allow-untyped-defs import inspect from collections import defaultdict -from functools import wraps +from functools import lru_cache, partial, wraps from itertools import chain -from typing import Callable, Dict, List, Sequence, TypeVar, Union +from typing import ( + Callable, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Set, + TypeVar, + Union, +) from typing_extensions import ParamSpec import torch import torch.library -from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket from torch._prims_common import CustomOutParamAnnotation +from torch._subclasses.functional_tensor import FunctionalTensor from torch.utils import _pytree as pytree @@ -20,6 +31,8 @@ "register_decomposition", "get_decompositions", "core_aten_decompositions", + "_decomp_table_to_post_autograd_aten", + "_special_op_to_preserve_cia", ] _T = TypeVar("_T") @@ -250,13 +263,184 @@ def remove_decompositions( import torch._refs +# Our strategy for deciding if we can preserve a op is following: +# 1. The op should be known statically that it is functional +# 2. If it is maybe aliasing, we decompose because we must know if an op +# is mutating or aliasing. +# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor +# decomp part. (https://github.com/pytorch/pytorch/issues/129431) +def _check_valid_to_preserve(op_overload): + if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: + return False + if op_overload in FunctionalTensor.metadata_fns: + return False + + alias_info = len( + [i for i in op_overload._schema.arguments if i.alias_info is not None] + ) + + is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable + + if is_mutating_or_aliasing: + return False + + if not torch._C._dispatch_has_kernel(op_overload.name()): + return False + + return True + + +def _is_cia_op(op: "OpOverload") -> bool: + return ( + torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ) + or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels + ) + + +@lru_cache(maxsize=1) +def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: + """ + This is an util function that gets the all CIA functional ops. + + The algorithm is in 2 steps: + 1. We first query C++ dispatcher to get the list of CIA ops + and then we call getattr on torch.ops.aten to lazily populate + them. + + 2. Sometimes, handful of ops have CIA registered in python dispatcher + but not on the C++ side, these can't be caught at the first step. + So we walk again to get the final list. + + Note that the output of this function should never be modified + """ + # First step to lazily populate torch.ops.aten + cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( + "CompositeImplicitAutograd" + ) + # Ignore quantized namespace ops + cia_ops = [name[6:] for name in cia_ops if name.startswith("aten::")] + # Materialize all CIA ops first + for op in cia_ops: + split_list = op.split(".") + # Sometime overload could be missing + assert len(split_list) == 1 or len(split_list) == 2 + op_name = split_list[0] + op_overload_name = "default" + if len(split_list) == 2: + op_overload_name = split_list[1] + + _ = getattr(getattr(torch.ops.aten, op_name), op_overload_name) + + # Second step to finally compile the list of all valid ops + cia_ops = set() + for op in torch.ops.aten: + op_packet = getattr(torch.ops.aten, op) + for overload in op_packet.overloads(): + op_overload = getattr(op_packet, overload) + if _check_valid_to_preserve(op_overload) and _is_cia_op(op_overload): + cia_ops.add(op_overload) + return cia_ops + + +def _get_decomp_for_cia(op): + # [NOTE] Seperating out func.decompose + # Ideally we should be able to just register func.decompose but + # we can't as this decomp is gonna be registered to the py_impl. + # As a result it will infinitely recurse. So we first check if the op + # has py_impl entry for CIA and if it is we use that first. If not, + # we register C++ query to py_impl. + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): + return op.py_kernels[dk] + + def _special_op_to_decompose_cia(*args, **kwargs): + kernel = kwargs["kernel"] + del kwargs["kernel"] + # Can't call kernel.decompose due to infinite recursion as + # we register this kernel to py_impl directly + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if torch._C._dispatch_has_kernel_for_dispatch_key( + kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + return kernel._op_dk(dk, *args, **kwargs) + else: + raise AssertionError( + f"Expected {kernel} to have CompositeImplicitAutograd kernel" + ) + + return partial(_special_op_to_decompose_cia, kernel=op) + + # See NOTE [Core ATen Ops] # # list was copied from torch/_inductor/decomposition.py # excluding decompositions that results in prim ops # Resulting opset of decomposition is core aten ops def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: + decomp_table = _core_aten_decompositions_post_autograd() + + # If it is fbcode change, we return the old decomposition list + from torch._inductor import config + + if config.is_fbcode(): + return decomp_table + + aten = torch.ops.aten + + # We are deleting custom decomp in core_aten_decomp + # for CIA ops but it should be fine technically + # because this table is only meant to be used in export context + # in which we really carefully control the decomp behaviour + # In any case, C++ decomps should be preferred + cia_ops_that_should_be_removed = [ + aten.all.dimname, + aten.index_add.dimname, + aten.index_copy.dimname, + aten.index_fill.Dimname_Scalar, + aten.index_fill.Dimname_Tensor, + aten.norm.names_ScalarOpt_dim_dtype, + aten.norm.names_ScalarOpt_dim, + aten.silu_backward.default, + aten.std.default, + aten.std.dim, + aten.std.names_dim, + aten.std.correction_names, + aten.std_mean.default, + aten.std_mean.dim, + aten.std_mean.names_dim, + aten.std_mean.correction_names, + aten.upsample_bilinear2d.vec, + aten.upsample_trilinear3d.vec, + ] + + for k in list(decomp_table.keys()): + if k in cia_ops_that_should_be_removed: + del decomp_table[k] + + for op in _collect_all_valid_cia_ops(): + decomp_table[op] = _get_decomp_for_cia(op) + return decomp_table + + +# This table is a stop-gap table which replicates +# the old behaviour of post-dispatch IR. +# This table contains all functional CIA ops mapping +# to their default decomp. In old export, this will +# be decomposed implicitly. +def _decomp_table_to_post_autograd_aten(): + decomp_table = {} + for k in _collect_all_valid_cia_ops(): + decomp_table[k] = _get_decomp_for_cia(k) + return decomp_table + + +def _core_aten_decompositions_post_autograd() -> ( + Dict[torch._ops.OperatorBase, Callable] +): aten = torch.ops.aten + # TODO Delete all mutating or CIA ops from this list return get_decompositions( [ aten.addcdiv, @@ -452,6 +636,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: aten.threshold_backward, aten.trace, aten.transpose.int, + aten.transpose_copy, aten.tril, aten.tril_, aten.triu, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index e4eb73dc624ed0..f3c09d762f33ae 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1409,20 +1409,17 @@ def split_with_sizes( sum(split_sizes) == self.shape[dim], lambda: f"Split sizes add up to {sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", ) - num_splits = len(split_sizes) - splits = [] - start_idx = 0 - # Avoid importing sympy at a module level - from torch.fx.experimental.symbolic_shapes import expect_true - - for i in range(num_splits): - length = split_sizes[i] - # We know this is true thanks to the sum, but this assertion helps - # out our internal reasoning - expect_true(start_idx + length <= self.shape[dim]) - splits.append(self.narrow(dim, start_idx, length)) - start_idx += length + splits = [] + offset = self.storage_offset() + + for split_size in split_sizes: + new_shape = list(self.shape) + new_shape[dim] = split_size + # We reimplement narrow here to avoid a lot of checks in the + # decomposition of narrow which calls slice_in_dim and slice + splits.append(self.as_strided(new_shape, self.stride(), offset)) + offset = offset + self.stride()[dim] * split_size return splits @@ -4367,10 +4364,10 @@ def matmul(tensor1, tensor2, *, is_out=False): if t2_is_matrix: # This copies if we perform a 2D @ 3D and the first tensor requires_grad # See should_fold native/LinearAlgebra.cpp for why. - output = t1_folded.mm(t2).view(output_shape) + output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape) return output.mT.contiguous() if transpose else output else: - return t1_folded.mv(t2).view(output_shape) + return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape) elif dim_tensor1 >= 1 and dim_tensor2 >= 1: # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); @@ -4911,9 +4908,7 @@ def scaled_dot_product_flash_attention_for_cpu( # Why this change? # In pre-dispatch export scaled_dot_product_attention is executed via # * flash_attention. - # flash_attention allocates output tensor as (N, L, H, E) - # it then transposes that to get (N, H, L, E) which is supposed to be the return - # tensor dim for scaled_dot_product_attention + # flash_attention allocates output tensor as (N, H, L, E) (see PR #134656) # assume x: [N, H, L, E] is the output sdpa # In MHA code, this output is then permuted via (2, 0, 1, 3) to get # (L, N, H, E) dim tensor @@ -4929,20 +4924,16 @@ def scaled_dot_product_flash_attention_for_cpu( # subsequent view is not valid and the export fails # solution is to maintain the return tensor view from the decomp to be # exactly same as *flash* variant. - # flash variants output is contiguous as [N, L, H, E] - # _match variant out is contiguous as [N, H, L, E] - # out = out.transpose(1, 2).contiguous gets output as contiguous - # in [N, L, H, E]. - # Subsrequent transpose(1, 2) then returns a view on which - # aforementioned code snippet, as showm below, is valid - # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via - # x = x.view(L * N, H * E) # Really the invariant you want to maintain is: # pre-dispatch op-output and its decomposed representation must # return tensor with same view and dims - output = output.transpose(1, 2).contiguous(memory_format=torch.contiguous_format) - return (output.transpose(1, 2), attn) + output = ( + output.permute(2, 0, 1, 3) + .contiguous(memory_format=torch.contiguous_format) + .permute(1, 2, 0, 3) + ) + return output, attn def register_inplace(aten_op, outplace_op): diff --git a/torch/_deploy.py b/torch/_deploy.py index 3f8adc420672db..0443a2447d00dd 100644 --- a/torch/_deploy.py +++ b/torch/_deploy.py @@ -24,10 +24,8 @@ def persistent_id(obj): if isinstance(obj, torch.storage.TypedStorage): # TODO: Once we decide to break serialization FC, we can # remove this case - storage = obj._untyped_storage dtype = obj.dtype else: - storage = obj dtype = torch.uint8 serialized_storages.append(obj) diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 00286a32750d1e..7f58ba7f7bf7f1 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -107,4 +107,3 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() - convert_frame.disabled_codes.clear() diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index 38f763f07898a1..c698ded100943a 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -52,6 +52,9 @@ class TraceWrapped(HigherOrderOperator): def __init__(self): super().__init__("trace_wrapped") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + # TODO(jansel): need to ensure this does not get DCEed _trace_wrapped_op = TraceWrapped() diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 87214f7d59774e..18784498862e39 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -31,6 +31,18 @@ def eager(gm, fake_tensor_inputs, **kwargs): return gm.forward +def make_eager_backend_with_torch_function_mode(mode): + """Used to trace HOPs (cond and while) for eager exectution, the metadata + TF mode mutates vars outside of the scope of the HOP, and we can't have graph breaks + in the HOP, so we need to externally run this mode and not trace it.""" + + def fn(gm, fake_tensor_inputs, **kwargs): + with mode: + return gm.forward + + return fn + + @register_backend def eager_noexcept(gm, fake_tensor_inputs, **kwargs): if kwargs: diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index e3538a4f6730f6..749d11937ea330 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -1,13 +1,18 @@ # mypy: ignore-errors import functools +import logging import sys +from importlib.metadata import EntryPoint from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple import torch from torch import fx +log = logging.getLogger(__name__) + + class CompiledFn(Protocol): def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: ... @@ -15,7 +20,8 @@ def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn] -_BACKENDS: Dict[str, CompilerFn] = {} +_BACKENDS: Dict[str, Optional[EntryPoint]] = {} +_COMPILER_FNS: Dict[str, CompilerFn] = {} def register_backend( @@ -39,8 +45,10 @@ def register_backend( return functools.partial(register_backend, name=name, tags=tags) assert callable(compiler_fn) name = name or compiler_fn.__name__ - assert name not in _BACKENDS, f"duplicate name: {name}" - _BACKENDS[name] = compiler_fn + assert name not in _COMPILER_FNS, f"duplicate name: {name}" + if compiler_fn not in _BACKENDS: + _BACKENDS[name] = None + _COMPILER_FNS[name] = compiler_fn compiler_fn._tags = tuple(tags) return compiler_fn @@ -56,13 +64,15 @@ def lookup_backend(compiler_fn): if isinstance(compiler_fn, str): if compiler_fn not in _BACKENDS: _lazy_import() - if compiler_fn not in _BACKENDS: - _lazy_import_entry_point(compiler_fn) if compiler_fn not in _BACKENDS: from ..exc import InvalidBackend raise InvalidBackend(name=compiler_fn) - compiler_fn = _BACKENDS[compiler_fn] + + if compiler_fn not in _COMPILER_FNS: + entry_point = _BACKENDS[compiler_fn] + register_backend(compiler_fn=entry_point.load(), name=compiler_fn) + compiler_fn = _COMPILER_FNS[compiler_fn] return compiler_fn @@ -74,13 +84,14 @@ def list_backends(exclude_tags=("debug", "experimental")) -> List[str]: """ _lazy_import() exclude_tags = set(exclude_tags or ()) - return sorted( - [ - name - for name, backend in _BACKENDS.items() - if not exclude_tags.intersection(backend._tags) - ] - ) + + backends = [ + name + for name in _BACKENDS.keys() + if name not in _COMPILER_FNS + or not exclude_tags.intersection(_COMPILER_FNS[name]._tags) + ] + return sorted(backends) @functools.lru_cache(None) @@ -94,22 +105,21 @@ def _lazy_import(): assert dynamo_minifier_backend is not None + _discover_entrypoint_backends() + @functools.lru_cache(None) -def _lazy_import_entry_point(backend_name: str): +def _discover_entrypoint_backends(): + # importing here so it will pick up the mocked version in test_backends.py from importlib.metadata import entry_points - compiler_fn = None group_name = "torch_dynamo_backends" if sys.version_info < (3, 10): - backend_eps = entry_points() - eps = [ep for ep in backend_eps.get(group_name, ()) if ep.name == backend_name] - if len(eps) > 0: - compiler_fn = eps[0].load() + eps = entry_points() + eps = eps[group_name] if group_name in eps else [] + eps = {ep.name: ep for ep in eps} else: - backend_eps = entry_points(group=group_name) - if backend_name in backend_eps.names: - compiler_fn = backend_eps[backend_name].load() - - if compiler_fn is not None and backend_name not in list_backends(()): - register_backend(compiler_fn=compiler_fn, name=backend_name) + eps = entry_points(group=group_name) + eps = {name: eps[name] for name in eps.names} + for backend_name in eps: + _BACKENDS[backend_name] = eps[backend_name] diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e68581762372e3..e7c5d2414f6e2b 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -237,14 +237,14 @@ def post_acc_grad_hook(self, input, hook_id): assert isinstance(input, torch.Tensor) assert self.hooks_proxy is not None hook = self.hooks_proxy[hook_id] # type: ignore[index] - proxies = self.proxy_call_hook( + proxy = self.proxy_call_hook( hook, input, hook_type="post_acc_grad_hook", ) with disable_proxy_modes_tracing(): input = [maybe_clone(input)] - self.bind_tensors_to_proxies(input, proxies) + self.bind_tensors_to_proxies(input, [proxy]) return input # Note: [Compiled autograd and cudagraphs] diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index f1a91e358c6d96..972d79d48fa8b2 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -233,10 +233,9 @@ def print_value_stack(self, *, file=None, stacklevel=0): NB: Stack grows downwards in our print """ - # TODO: improve printing tx = self.__get_tx(stacklevel) for s in tx.stack: - print(f"- {s}", file=file) + print(f"- {s.debug_repr()}", file=file) def print_locals(self, *, file=None, stacklevel=0): """ @@ -244,10 +243,9 @@ def print_locals(self, *, file=None, stacklevel=0): By default this view is very limited; you can get more information about any individual local using get_local(). """ - # TODO: improve by improving the VariableTracker printing tx = self.__get_tx(stacklevel) for k, v in tx.symbolic_locals.items(): - print(f"{k} = {v}", file=file) + print(f"{k} = {v.debug_repr()}", file=file) def print_bt(self, *, file=None, stacklevel=0): """ diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 88081a87290070..2ba29961af36e9 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -48,6 +48,9 @@ def is_fbcode(): # [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps accumulated_cache_size_limit = 256 +# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit +skip_code_recursive_on_cache_limit_hit = True + # whether or not to specialize on int inputs. This only has an effect with # dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int # inputs. Note that assume_static_by_default will also cause ints to get @@ -98,7 +101,7 @@ def is_fbcode(): allow_ignore_mark_dynamic = False # Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing) -guard_nn_modules = False if is_fbcode() else True +guard_nn_modules = True # Uses CPython internal dictionary tags to detect mutation. There is some # overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag. @@ -374,6 +377,10 @@ def _get_optimize_ddp_mode(): # Inline inbuilt nn modules inline_inbuilt_nn_modules = not is_fbcode() +# When set, total compile time instruction count is recorded using +# torch._dynamo.utilsCompileTimeInstructionCounter. +record_compile_time_instruction_count = False + def default_debug_dir_root(): # [@compile_ignored: debug] diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 2f0da2be866db1..1b71c42b9ac5a7 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -2,6 +2,7 @@ from __future__ import annotations import collections +import contextlib import cProfile import dis import functools @@ -27,10 +28,12 @@ import torch._logging from torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg +from torch._dynamo.utils import CompileTimeInstructionCounter from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( compile_time_strobelight_meta, + justknobs_check, maybe_upload_prof_stats_to_manifold, signpost_event, ) @@ -66,8 +69,10 @@ from .exc import ( augment_exc_message, BackendCompilerFailed, + CacheLimitExceeded, format_error_msg, InternalTorchDynamoError, + SkipCodeRecursiveException, TorchRuntimeError, UncapturedHigherOrderOpError, unimplemented, @@ -107,6 +112,7 @@ troubleshooting_url, write_record_to_file, ) +from .variables.torch_function import torch_function_mode_stack_state_mgr np: Optional[ModuleType] @@ -160,7 +166,6 @@ def clear(self) -> None: input_codes = Tracker() output_codes = Tracker() -disabled_codes: Dict[int, Callable[..., Any]] = {} initial_global_state: Optional[GlobalStateGuard] = None @@ -206,10 +211,19 @@ def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: prior_fwd_from_src = torch.fx.graph_module._forward_from_src torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result cleanup = setup_compile_debug() + exit_stack = contextlib.ExitStack() + exit_stack.enter_context( + torch.fx._symbolic_trace._maybe_revert_all_patches() + ) + exit_stack.enter_context(torch_function_mode_stack_state_mgr) try: return fn(*args, **kwargs) finally: cleanup.close() + assert ( + torch._C._len_torch_function_stack() == 0 + ), "Torch function mode stack state changed while dynamo tracing, please report a bug" + exit_stack.close() torch._C._set_grad_enabled(prior_grad_mode) torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode) torch.use_deterministic_algorithms( @@ -595,6 +609,10 @@ def _compile( output: Optional[OutputGraph] = None tracer: Optional[InstructionTranslator] = None + tf_mode_stack: List[ + torch.overrides.TorchFunctionMode + ] = torch.overrides._get_current_function_mode_stack() + @preserve_global_state def transform( instructions: List[Instruction], code_options: Dict[str, object] @@ -608,6 +626,7 @@ def transform( locals, globals, builtins, + tf_mode_stack, code_options, compiler_fn, one_graph, @@ -652,7 +671,8 @@ def compile_inner( transform: Callable[[List[Instruction], Dict[str, Any]], Any], ) -> Optional[GuardedCode]: with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"): - return _compile_inner(code, one_graph, hooks, transform) + with CompileTimeInstructionCounter.record(): + return _compile_inner(code, one_graph, hooks, transform) @compile_time_strobelight_meta(phase_name="compile_inner") @maybe_cprofile @@ -842,7 +862,13 @@ def format_guard_failures() -> str: format_guard_failures(), troubleshooting_url, ) - unimplemented(f"{limit_type} reached") + if config.skip_code_recursive_on_cache_limit_hit and justknobs_check( + "pytorch/compiler:skip_code_recursive_on_cache_limit_hit" + ): + raise CacheLimitExceeded(f"{limit_type} reached") + else: + # do not recursively skip frames + unimplemented(f"{limit_type} reached") log.debug( "torchdynamo start compiling %s %s:%s, stack (elided %s frames):\n%s", @@ -906,34 +932,44 @@ def format_guard_failures() -> str: try: guarded_code = compile_inner(code, one_graph, hooks, transform) return guarded_code - except ( - Unsupported, - TorchRuntimeError, - BackendCompilerFailed, - AssertionError, - ConstraintViolationError, - GuardOnDataDependentSymNode, - ValidationException, - UncapturedHigherOrderOpError, - BisectValidationException, - ) as e: - fail_type = str(type(e)) - fail_reason = str(e) - exception_handler(e, code, frame, export=export) - fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( - e, compile_id - ) - raise except Exception as e: - fail_type = str(type(e)) + fail_type = type(e).__qualname__ fail_reason = str(e) + # NB: e's msg is mutated here to add user stack, but we DON'T want + # that stack in the Scuba logged fail_reason exception_handler(e, code, frame, export=export) + # NB: this is the post-mutation exception + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_error", + "encoding": "string", + }, + payload_fn=lambda: traceback.format_exc(), + ) fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message( e, compile_id ) - raise InternalTorchDynamoError(str(e)).with_traceback( - e.__traceback__ - ) from None + if isinstance( + e, + ( + Unsupported, + TorchRuntimeError, + BackendCompilerFailed, + AssertionError, + ConstraintViolationError, + GuardOnDataDependentSymNode, + ValidationException, + UncapturedHigherOrderOpError, + BisectValidationException, + ), + ): + raise + else: + # Rewrap for clarity + raise InternalTorchDynamoError( + f"{type(e).__qualname__}: {str(e)}" + ).with_traceback(e.__traceback__) from None finally: if tracer: tracer.output.local_scope = {} @@ -971,6 +1007,9 @@ def format_guard_failures() -> str: ] - start_possibly_missed_reinplacing_opportunities ) + remote_cache_time_saved = frame_phase_timing[frame_key].get( + "remote_cache_time_saved", 0 + ) else: guard_count = None shape_env_guard_count = None @@ -987,6 +1026,11 @@ def format_guard_failures() -> str: # If compilation failed, the entire time is wasted dynamo_time_before_restart = time.time() - start_time possibly_missed_reinplacing_opportunities = None + remote_cache_time_saved = None + + structured_logging_overhead_s = ( + torch._logging.get_structured_logging_overhead() + ) metrics = CompilationMetrics( str(compile_id), @@ -1016,6 +1060,8 @@ def format_guard_failures() -> str: dynamo_time_before_restart, guarded_code is not None, possibly_missed_reinplacing_opportunities, + remote_cache_time_saved, + structured_logging_overhead_s, ) record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() @@ -1038,7 +1084,9 @@ def __call__( hooks: Hooks, frame_state: Dict[str, Union[int, FrameStateSizeEntry]], skip: int = 0, - ) -> Optional[GuardedCode]: + ) -> Optional[ + Union[GuardedCode, torch._C._dynamo.eval_frame.SkipCodeRecursiveFlag] + ]: counters["frames"]["total"] += 1 try: result = self._inner_convert( @@ -1103,6 +1151,12 @@ def __call__( log.info(error_msg, exc_info=True) else: log.warning(error_msg, exc_info=True) + + # If we encounter SkipCodeRecursiveException, return skip_code_recursive_flag + # to signal to Dynamo eval frame to skip the current frame and any recursive calls. + if isinstance(e, SkipCodeRecursiveException): + return torch._C._dynamo.eval_frame.skip_code_recursive_flag + return None diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index cfec563729a480..94687ff2747bf3 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" - +import atexit import copy +import cProfile import functools import getpass import inspect @@ -10,6 +11,7 @@ import os import re import subprocess +import sys import tempfile import textwrap from collections import Counter @@ -361,7 +363,9 @@ def same_two_models( fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) except Exception: if require_fp64: - raise RuntimeError("Could not generate fp64 outputs") # noqa: B904 + raise RuntimeError( # noqa: B904 + "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False" + ) log.warning("Could not generate fp64 outputs") try: @@ -780,3 +784,41 @@ def gen_tensor(shape, dtype) -> Tensor: setattr(container, attr_name, gen_tensor(shape, dtype)) return kwargs + + +def profile_to_file(filename: str) -> Callable[[T], T]: + """ + Decorator to cProfile a given function and save the result to disk on process exit. + + Args: + filename: filename to save profile to + """ + prof = cProfile.Profile() + filename = os.path.abspath(os.path.expanduser(filename)) + + def decorator(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + prof.enable() + try: + return fn(*args, **kwargs) + finally: + prof.disable() + + return wrapper + + def save_it(): + prof.dump_stats(filename) + sys.stderr.write( + textwrap.dedent( + f"""\ + Wrote profile to {filename}, view with: + + snakeviz {filename} + + """ + ) + ) + + atexit.register(save_it) + return decorator diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 6b557d278a65cf..67d6c0f27a4c26 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -3,14 +3,13 @@ import functools import inspect from dataclasses import dataclass -from typing import Any, Callable, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar import torch from torch.utils._python_dispatch import is_traceable_wrapper_subclass from . import trace_rules, variables from .comptime import comptime -from .convert_frame import disabled_codes from .eval_frame import DisableContext, innermost_fn, RunOnlyContext from .exc import IncorrectUsage from .external_utils import is_compiling @@ -18,6 +17,8 @@ if TYPE_CHECKING: + from types import FunctionType + from torch._C._dynamo.eval_frame import ( # noqa: F401 reset_code, set_eval_frame, @@ -25,6 +26,8 @@ skip_code, unsupported, ) + + from .variables import VariableTracker else: for name in dir(torch._C._dynamo.eval_frame): if name.startswith("__"): @@ -56,15 +59,9 @@ def disable(fn=None, recursive=True): """ if recursive: if fn is not None: - id_fn = id(fn) - if cached_fn := disabled_codes.get(id_fn): - return cached_fn - fn = innermost_fn(fn) assert callable(fn) - out = DisableContext()(fn) - disabled_codes[id_fn] = out - return out + return DisableContext()(fn) return DisableContext() else: return skip(fn) @@ -175,7 +172,10 @@ def forbid_in_graph(fn): def substitute_in_graph( original_fn: _F, *, + can_constant_fold_through: bool = False, skip_signature_check: bool = False, + # type that is embedded in the Python interpreter + is_embedded_type: bool = False, # internal use only ) -> Callable[[_F], _F]: """ Register a polyfill handler for a function, usually a C function from the C extension, to be @@ -194,6 +194,10 @@ def substitute_in_graph( Args: original_fn (callable): The original function, usually a C function, to register a polyfill handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. skip_signature_check (bool, optional): Whether to skip the signature check between the original function and the polyfill handler. Defaults to ``False``. @@ -221,10 +225,22 @@ def substitute_in_graph( >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) 2 """ - if not is_function(original_fn): + if not is_function(original_fn) and not ( + is_embedded_type and inspect.isclass(original_fn) + ): raise TypeError( f"substitute_in_graph expects a function but got {type(original_fn)!r}" ) + if is_embedded_type: + if not inspect.isclass(original_fn): + raise TypeError( + f"substitute_in_graph expects a class but got {type(original_fn)!r}" + ) + + from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS + + if id(original_fn) in ITERTOOLS_TYPE_IDS: + ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn)) def wrapper(traceable_fn: _F) -> _F: if not is_function(traceable_fn): @@ -282,7 +298,10 @@ def sig_ident(sig): ) from torch._dynamo.guards import GuardBuilder - from torch._dynamo.trace_rules import get_torch_obj_rule_map + from torch._dynamo.trace_rules import ( + _polyfilled_function_ids, + get_torch_obj_rule_map, + ) from torch._dynamo.variables import PolyfilledFunctionVariable from torch._dynamo.variables.builder import VariableBuilder @@ -293,13 +312,17 @@ def sig_ident(sig): "already registered in VariableBuilder's id dispatch map" ) - rule_map = get_torch_obj_rule_map() + if id(original_fn) in _polyfilled_function_ids: + raise ValueError(f"Duplicate polyfilled object {original_fn}") + + rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map() if original_fn in rule_map: raise ValueError( f"Duplicate object {original_fn} with different rules: " f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}" ) + polyfill_handlers: Dict[Callable[..., Any], FunctionType] polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() if original_fn in polyfill_handlers: raise ValueError( @@ -313,19 +336,22 @@ def sig_ident(sig): def wrapped(*args, **kwargs): return original_fn(*args, **kwargs) - def dispatch_fn(self, value): + def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable: return PolyfilledFunctionVariable( value, source=self.source, - **self.install_guards(GuardBuilder.CLOSURE_MATCH), + **self.install_guards(GuardBuilder.FUNCTION_MATCH), ) id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn + _polyfilled_function_ids.add(id(original_fn)) + _polyfilled_function_ids.add(id(wrapped)) rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable - polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = traceable_fn + polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment] wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined] wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined] + wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined] return wrapped # type: ignore[return-value] @@ -454,8 +480,10 @@ def maybe_mark_dynamic(t, index): def mark_static(t, index=None): """ - Mark a tensor as having a static dim. + Mark a tensor as having a static dim or mark a nn module class as static. + For tensors + =========== This will prevent us from attempting to compile it dynamically when dynamic=True; this can improve trace-time performance. @@ -463,6 +491,20 @@ def mark_static(t, index=None): Unlike mark_dynamic, this can be done inside a graph, in which case it induces specialization on the tensor. + + For nn.Module classes + ===================== + For static nn.Module classes, TorchDynamo assumes that the module instance + attributes will not be modified after compilation. This will ensure that + TorchDynamo keeps integer attributes CONSTANT and not symints. + + From TorchDynamo implementation side, the instances of static-marked + nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable, + which have the same properties. + + Note that we still have to guard on the attributes, because different + instances of the nn.Module can have different values of the attributes. The + key point here is that the attributes are static. """ if is_compiling(): if index is None: @@ -477,11 +519,20 @@ def mark_static(t, index=None): # TODO: Make this configurable via a supported public API _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index) + if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module): + t._dynamo_marked_static = True + return t + + if not isinstance(t, torch.Tensor): + raise TypeError( + f"mark_static expects a tensor/nn.Module class but recieved {type(t)}" + ) + if isinstance(index, int): if not hasattr(t, "_dynamo_static_indices"): - t._dynamo_static_indices = set() + t._dynamo_static_indices = set() # type: ignore[attr-defined] # TODO(voz): Should we bounds check? - t._dynamo_static_indices.add(index) + t._dynamo_static_indices.add(index) # type: ignore[attr-defined] elif index is None: for i in range(t.dim()): mark_static(t, i) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 5670172c49c52c..00975cda550818 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -123,6 +123,10 @@ def get_compute_capability(device: _device_t = None): def is_bf16_supported(including_emulation: bool = False): raise NotImplementedError + @staticmethod + def memory_allocated(device: _device_t = None) -> int: + raise NotImplementedError + class DeviceGuard: """ @@ -202,6 +206,7 @@ def get_device_properties(device: _device_t = None): get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.cuda.memory_allocated) is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] # Can be mock patched by @patch decorator. @@ -273,6 +278,7 @@ def get_device_properties(device: _device_t = None): get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] + memory_allocated = staticmethod(torch.xpu.memory_allocated) # Can be mock patched by @patch decorator. @staticmethod diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index ed0f89097987f9..c04e4ccb00a97a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -104,7 +104,7 @@ def _maybe_set_eval_frame(callback: DynamoCallback): from torch._C._dynamo.eval_frame import set_eval_frame if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): - log.warning( + torch._dynamo.utils.warn_once( "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame" ) return callback diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index c281a9d68c4868..0d2108ada9e10e 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -168,6 +168,14 @@ def __init__(self, error_type: UserErrorType, msg, case_name=None) -> None: self.message = msg +class SkipCodeRecursiveException(TorchDynamoException): + pass + + +class CacheLimitExceeded(SkipCodeRecursiveException, Unsupported): + pass + + class UnsafeScriptObjectError(TorchDynamoException): pass diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d9cc7eb6d44d16..3dcc9a032f208f 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -65,6 +65,7 @@ Source, ) from torch._logging import structured +from torch._utils_internal import justknobs_check from torch.fx.experimental.symbolic_shapes import ( EqualityConstraint, is_symbolic, @@ -76,7 +77,9 @@ from . import config, convert_frame, exc, mutation_guard from .eval_frame import set_guard_error_hook from .source import ( + AttrProxySource, AttrSource, + CallFunctionNoArgsSource, ChainedSource, ConstDictKeySource, DefaultsSource, @@ -95,6 +98,7 @@ ScriptObjectQualifiedNameSource, ShapeEnvSource, SubclassAttrListSource, + TorchFunctionModeStackSource, TupleIteratorGetItemSource, TypeSource, UnspecializedBuiltinNNModuleSource, @@ -108,6 +112,7 @@ dict_keys_repr, get_custom_getattr, get_torch_function_mode_stack, + get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, @@ -311,6 +316,7 @@ def uninteresting_files(): "___dict_contains": lambda a, b: a in b, "___tuple_iterator_len": tuple_iterator_len, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -898,6 +904,15 @@ def get_guard_manager_from_source(self, source): ): assert base_guard_manager # to make mypy happy out = base_guard_manager + elif istype(source, TorchFunctionModeStackSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_torch_function_mode_stack_at( + source._get_index() + ), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -1064,6 +1079,14 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, AttrProxySource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: x.get_base(), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, TupleIteratorGetItemSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.tuple_iterator_getitem_manager( @@ -1083,13 +1106,20 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) - elif isinstance(source, WeakRefCallSource): + elif istype(source, WeakRefCallSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.weakref_call_manager( source=source_name, example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CallFunctionNoArgsSource): + assert base_guard_manager # to make mypy happy + out = base_guard_manager.call_function_no_args_manager( + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" @@ -1464,12 +1494,12 @@ def EQUALS_MATCH(self, guard: Guard): ) if torch.distributed.is_available(): - from torch.distributed._tensor.placement_types import ( + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.tensor.placement_types import ( Partial, Replicate, Shard, ) - from torch.distributed.device_mesh import DeviceMesh ok_types = ok_types + ( Shard, @@ -2196,6 +2226,8 @@ def __init__( self.output_graph = output_graph w_builder = None + # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing + # in case a set default device call was made in the graph. self.torch_function_mode_stack = ( output_graph.torch_function_mode_stack if output_graph else None ) @@ -2229,9 +2261,16 @@ def cleanup_builder(weak_b): # Break retain cycle. See test_release_input_memory w_builder = weakref.ref(builder, cleanup_builder) + guard_on_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) + + if not justknobs_check("pytorch/compiler:guard_nn_modules"): + log.warning("guard_nn_modules is turned off using justknobs killswitch") + for guard in sorted(guards or [], key=Guard.sort_key): if ( - not config.guard_nn_modules + not guard_on_nn_modules and guard.is_specialized_nn_module() # Default func args must be guarded on. # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API @@ -2305,15 +2344,12 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): ) if config.enable_cpp_guard_manager: - from .variables.torch_function import IGNORED_MODES - # Insert the global_state guard assert self.guard_manager # to make mypy happy self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) self.guard_manager.root.add_torch_function_mode_stack_guard( self.torch_function_mode_stack, - list(IGNORED_MODES), ["___check_torch_function_mode_stack()"], ) # Clear references to torch_function modes held in the list @@ -2620,16 +2656,14 @@ def is_recompiles_verbose_enabled(): # this will only be used if cpp guards are disabled def make_torch_function_mode_stack_guard(intial_stack): types = [type(x) for x in intial_stack] - from .variables.torch_function import IGNORED_MODES def check_torch_function_mode_stack(): cur_stack = get_torch_function_mode_stack() + if len(cur_stack) != len(types): return False for ty, mode in zip(types, cur_stack): - if ty in IGNORED_MODES: - continue if ty != type(mode): return False diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 5684de69601565..76be81a088c3c9 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -78,7 +78,6 @@ get_instruction_source_311, get_locals_to_steal, get_static_address_type, - get_torch_function_mode_stack, graph_break_reasons, increment_op_count, lazy_format_graph_code, @@ -250,6 +249,7 @@ def __init__( local_scope: Scope, global_scope: Scope, f_code, + torch_function_mode_stack, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] @@ -326,7 +326,7 @@ def __init__( ] = collections.defaultdict(list) # Stores the full fqn of a param or buffer to the relevant source. self.param_name_to_source: Optional[Dict[str, Source]] = {} - self.side_effects = SideEffects() + self.side_effects = SideEffects(self) # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL # and LOAD_ATTR for same python objects free. self.variable_tracker_cache = VariableTrackerCache() @@ -368,7 +368,7 @@ def __init__( # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled() # This records the initial torch function mode stack for guarding - self.torch_function_mode_stack = get_torch_function_mode_stack() + self.torch_function_mode_stack = torch_function_mode_stack # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager @@ -1020,7 +1020,7 @@ def append_prefix_insts(): prefix_insts.clear() for block in reversed(tx.block_stack): - block.exit(tx) + block.exit(tx, is_graph_break=reason.graph_break) self.cleanup_graph() tx.prune_dead_locals() @@ -1462,7 +1462,9 @@ def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: # aborting execution. raise e except Exception as e: - raise BackendCompilerFailed(self.compiler_fn, e) from e + raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( + e.__traceback__ + ) from None signpost_event( "dynamo", @@ -1832,6 +1834,14 @@ def __init__( # Dicts maintain the order of args for the HigherOrderOperator call. self.lifted_freevars = {} self.prev_inst = None + # True if this tracer is currently tracing into torch.utils.checkpoint + # as part of speculate_subgraph. + self.under_activation_checkpoint = False + # True if we want to allow side-effects (doesn't throw error on their existence) + # during this tracer's tracing of torch.utils.checkpoint (via speculate_subgraph). + # Only safe if we know for sure that *NOT* replaying these side-effects during + # backward recomputation of the checkpoint region doesn't affect its correctness. + self.allow_side_effects_under_checkpoint = False self._cur_code = None self._orig_gm_meta = None diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 9b465726754bbe..5b2812bc08c9e5 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -4,16 +4,51 @@ # NOTE: 1. Please do not import any submodule in the directory here to avoid circular imports. # 2. While adding a new polyfill module, also add it to POLYFILLED_MODULE_NAMES in loader.py. +# Add it in the TYPE_CHECKING block below as well. # mypy: allow-untyped-defs -import math -from typing import Any, Callable, Sequence +from typing import Any, Callable, Sequence, TYPE_CHECKING import torch +if TYPE_CHECKING: + # Load by torch._dynamo.polyfills.loader + # See also the POLYFILLED_MODULE_NAMES in torch/_dynamo/polyfills/loader.py + # Put the submodules here to avoid circular imports + from . import ( + builtins as builtins, + functools as functools, + itertools as itertools, + os as os, + sys as sys, + ) + +from torch.overrides import BaseTorchFunctionMode + + +# These classes handle support for TorchFunctionModes across +# graph breaks +# Today the TorchFunctionMode enter (for the classes we support) +# simply pushes the mode onto the stack. Since after this occurs +# the stack is mutated, and we replay these mutations, we don't need +# any cleanup logic to be run once the graph break occurs, we simply replay +# these mutations to ensure at the graph break the torch function mode stack is correct +# and reconstruct the torch function mode stack normally +# when we compile the resume function on the other side of the break. +# However, to ensure we exit properly +# in the resume function, we need to re-enter the contexts as we do other contexts. +# These contexts do nothing on enter, but provide the correct exit logic to ensure +# the stack state is correct. +class NoEnterTorchFunctionMode(BaseTorchFunctionMode): + def __enter__(self): + pass + + def index(iterator, item, start=0, end=None): + from itertools import islice + for i, elem in islice(enumerate(iterator), start, end): if item == elem: return i @@ -21,35 +56,14 @@ def index(iterator, item, start=0, end=None): raise ValueError(f"{item} is not in {type(iterator)}") -def islice(iterator, start=0, end=None, step=1): - if start < 0 or (end is not None and end < 0) or step < 0: - raise ValueError("Indices must be non-negative") - if step == 0: - raise ValueError("Step cannot be 0") - - it = iter(iterator) - - for _ in range(start): - next(it) - - if end is None: - for i, element in enumerate(it): - if i % step == 0: - yield element - else: - for i, element in enumerate(it): - if i % step == 0 and i + start < end - start: - yield element - elif i + start >= end - start: - break - - def repeat(item, count): for i in range(count): yield item def radians(x): + import math + return math.pi / 180.0 * x @@ -154,3 +168,19 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): if isinstance(obj, cls): obj.__init__(*args, **kwargs) return obj + + +def foreach_lerp_inplace(self, end, weight): + # decompose foreach lerp into constituent ops, prevents a graph break due to + # converting a value to a scalar when arg[2] is a single tensor + result = torch._foreach_sub(end, self) + result = torch._foreach_mul(result, weight) + return torch._foreach_add_(self, result) + + +def foreach_pow_scalar(scalar, exps): + return torch._foreach_pow([scalar for _ in exps], exps) + + +def addcmul_inplace(self, tensor1, tensor2, value): + return self.add_(tensor1 * tensor2 * value) diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index d9c3c2644c31ba..62305086b804fe 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -2,6 +2,8 @@ Python polyfills for builtins """ +from __future__ import annotations + import builtins import functools import operator @@ -13,6 +15,7 @@ __all__ = [ "all", "any", + "enumerate", "sum", ] @@ -20,7 +23,7 @@ _T = TypeVar("_T") -@substitute_in_graph(builtins.all) +@substitute_in_graph(builtins.all, can_constant_fold_through=True) def all(iterable: Iterable[object], /) -> bool: for elem in iterable: if not elem: @@ -28,7 +31,7 @@ def all(iterable: Iterable[object], /) -> bool: return True -@substitute_in_graph(builtins.any) +@substitute_in_graph(builtins.any, can_constant_fold_through=True) def any(iterable: Iterable[object], /) -> bool: for elem in iterable: if elem: @@ -36,6 +39,18 @@ def any(iterable: Iterable[object], /) -> bool: return False -@substitute_in_graph(builtins.sum) # type: ignore[arg-type] +@substitute_in_graph(builtins.enumerate, is_embedded_type=True) # type: ignore[arg-type] +def enumerate(iterable: Iterable[_T], start: int = 0) -> Iterable[tuple[int, _T]]: + if not isinstance(start, int): + raise TypeError( + f"{type(start).__name__!r} object cannot be interpreted as an integer" + ) + + for x in iterable: + yield start, x + start += 1 + + +@substitute_in_graph(builtins.sum, can_constant_fold_through=True) # type: ignore[arg-type] def sum(iterable: Iterable[_T], /, start: _T = 0) -> _T: # type: ignore[assignment] return functools.reduce(operator.add, iterable, start) diff --git a/torch/_dynamo/polyfills/itertools.py b/torch/_dynamo/polyfills/itertools.py index 7207acd6504b51..63266e85a2b8a1 100644 --- a/torch/_dynamo/polyfills/itertools.py +++ b/torch/_dynamo/polyfills/itertools.py @@ -2,21 +2,90 @@ Python polyfills for itertools """ +from __future__ import annotations + import itertools -from typing import Iterable, Iterator, Tuple, TypeVar +import sys +from typing import Iterable, Iterator, TypeVar from ..decorators import substitute_in_graph -__all__ = ["tee"] +__all__ = [ + "chain", + "chain_from_iterable", + "islice", + "tee", +] _T = TypeVar("_T") +# Reference: https://docs.python.org/3/library/itertools.html#itertools.chain +@substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type] +def chain(*iterables: Iterable[_T]) -> Iterator[_T]: + for iterable in iterables: + yield from iterable + + +@substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type] +def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]: + return itertools.chain(*iterable) + + +chain.from_iterable = chain_from_iterable # type: ignore[method-assign] + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.islice +@substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type] +def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]: + s = slice(*args) + start = 0 if s.start is None else s.start + stop = s.stop + step = 1 if s.step is None else s.step + if start < 0 or (stop is not None and stop < 0) or step <= 0: + raise ValueError( + "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.", + ) + + if stop is None: + # TODO: use indices = itertools.count() and merge implementation with the else branch + # when we support infinite iterators + next_i = start + for i, element in enumerate(iterable): + if i == next_i: + yield element + next_i += step + else: + indices = range(max(start, stop)) + next_i = start + for i, element in zip(indices, iterable): + if i == next_i: + yield element + next_i += step + + +# Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise +if sys.version_info >= (3, 10): + + @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type] + def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]: + a = None + first = True + for b in iterable: + if first: + first = False + else: + yield a, b # type: ignore[misc] + a = b + + __all__ += ["pairwise"] + + # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee @substitute_in_graph(itertools.tee) -def tee(iterable: Iterable[_T], n: int = 2, /) -> Tuple[Iterator[_T], ...]: +def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: iterator = iter(iterable) shared_link = [None, None] diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index cce1ba183ec695..24478e1b5a0f96 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -11,11 +11,13 @@ from types import ModuleType +# See also the TYPE_CHECKING block in torch/_dynamo/polyfills/__init__.py POLYFILLED_MODULE_NAMES: Tuple[str, ...] = ( "builtins", "functools", "itertools", "os", + "sys", ) POLYFILLED_MODULES: Tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/os.py b/torch/_dynamo/polyfills/os.py index 013cc81ac30e6f..5388816b826742 100644 --- a/torch/_dynamo/polyfills/os.py +++ b/torch/_dynamo/polyfills/os.py @@ -14,7 +14,7 @@ # Copied from os.py in the standard library -@substitute_in_graph(os.fspath) +@substitute_in_graph(os.fspath, can_constant_fold_through=True) def fspath(path: AnyStr | os.PathLike[AnyStr]) -> AnyStr: if isinstance(path, (str, bytes)): return path diff --git a/torch/_dynamo/polyfills/sys.py b/torch/_dynamo/polyfills/sys.py new file mode 100644 index 00000000000000..e9479eda8a0862 --- /dev/null +++ b/torch/_dynamo/polyfills/sys.py @@ -0,0 +1,6 @@ +""" +Python polyfills for sys +""" + + +__all__ = [] # type: ignore[var-annotated] diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index a21a20cdf2e611..19aa69d084db36 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -521,7 +521,7 @@ def repro_common(options, mod, load_args): def repro_minifier_query(options, mod, load_args): mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( - ACCURACY_FAILS[options.accuracy], check_str=options.check_str + ACCURACY_FAILS[options.accuracy], check_str=options.check_str # type: ignore[call-arg] ) if fail_fn(mod, args): sys.exit(1) diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index b1bb950d11c11b..214480b02c9332 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -12,6 +12,7 @@ import torch import torch.fx as fx +from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -271,8 +272,10 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -@register_debug_backend -def dynamo_minifier_backend(gm, example_inputs, compiler_name): +@register_debug_backend # type: ignore[arg-type] +def dynamo_minifier_backend( + gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn +): from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -311,7 +314,7 @@ def dynamo_minifier_backend(gm, example_inputs, compiler_name): return gm -@register_debug_backend +@register_debug_backend # type: ignore[arg-type] def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): from functorch.compile import minifier diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 132e9e4081bceb..2013b3992eb54f 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple from .bytecode_transformation import ( + add_push_null, create_call_function, create_call_method, create_dup_top, @@ -48,6 +49,109 @@ class ReenterWith: stack_index: int target_values: Optional[Tuple[Any, ...]] = None + def try_except_torch_function_mode(self, code_options, cleanup: List[Instruction]): + """ + Codegen based off of: + try: + (rest) + except: + (restore previous stack) + + """ + from .variables.torch_function import get_prev_stack_var_name + + except_jump_target = create_instruction( + "NOP" if sys.version_info < (3, 11) else "PUSH_EXC_INFO" + ) + cleanup_complete_jump_target = create_instruction("NOP") + + setup_finally: List[Instruction] = [] + + if sys.version_info < (3, 11): + setup_finally.append( + create_instruction("SETUP_FINALLY", target=except_jump_target) + ) + else: + exn_tab_begin = create_instruction("NOP") + exn_tab_end = create_instruction("NOP") + exn_tab_begin.exn_tab_entry = InstructionExnTabEntry( + exn_tab_begin, + exn_tab_end, + except_jump_target, + self.stack_index + 1, + False, + ) + setup_finally.append(exn_tab_begin) + + def create_reset(): + insts = [ + create_instruction( + "LOAD_GLOBAL", argval="__import_torch_dot__dynamo_dot_utils" + ), + create_instruction("LOAD_ATTR", argval="set_torch_function_mode_stack"), + ] + add_push_null(insts) + return [ + *insts, + create_instruction("LOAD_FAST", argval=get_prev_stack_var_name()), + *create_call_function(1, False), + create_instruction("POP_TOP"), + ] + + if sys.version_info < (3, 9): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + *create_reset(), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + create_instruction("END_FINALLY"), + cleanup_complete_jump_target, + ] + elif sys.version_info < (3, 11): + epilogue = [ + create_instruction("POP_BLOCK"), + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + create_instruction("POP_TOP"), + *create_reset(), + create_instruction("RAISE_VARARGS", argval=0), + create_instruction("POP_EXCEPT", argval=0), + cleanup_complete_jump_target, + ] + else: + finally_exn_tab_end = create_instruction("RAISE_VARARGS", argval=0) + finally_exn_tab_target = create_instruction("COPY", arg=3) + except_jump_target.exn_tab_entry = InstructionExnTabEntry( + except_jump_target, + finally_exn_tab_end, + finally_exn_tab_target, + self.stack_index + 2, + True, + ) + epilogue = [ + exn_tab_end, + create_instruction("JUMP_FORWARD", target=cleanup_complete_jump_target), + except_jump_target, # PUSH_EXC_INFO + create_instruction("POP_TOP"), + *create_reset(), + finally_exn_tab_end, + finally_exn_tab_target, # COPY 3 + create_instruction("POP_EXCEPT"), + create_instruction("RERAISE", arg=1), # RERAISE 1 + cleanup_complete_jump_target, + ] + + cleanup[:] = epilogue + cleanup + return setup_finally + # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol def try_except(self, code_options, cleanup: List[Instruction]): @@ -367,14 +471,14 @@ def generate( code, lineno, offset: int, - setup_fn_target_offsets: Tuple[int], # only used in Python 3.11+ + setup_fn_target_offsets: Tuple[int, ...], # only used in Python 3.11+ nstack: int, - argnames: Tuple[str], - argnames_null: Tuple[str], - setup_fns: Tuple[ReenterWith], - stack_ctx_vars: Tuple[int, Tuple[Any]], - argnames_ctx_vars: Tuple[str, Tuple[Any]], - null_idxes: Tuple[int], + argnames: Tuple[str, ...], + argnames_null: Tuple[str, ...], + setup_fns: Tuple[ReenterWith, ...], + stack_ctx_vars: Tuple[Tuple[int, Tuple[Any]], ...], + argnames_ctx_vars: Tuple[Tuple[str, Tuple[Any]], ...], + null_idxes: Tuple[int, ...], ) -> types.CodeType: assert offset is not None assert not ( diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 0f8ad871425dae..01891dc9e196db 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs +import contextlib import functools import inspect import warnings +import weakref from collections.abc import MutableMapping from typing import Any, Dict, List, Optional, Type, Union @@ -17,13 +19,14 @@ from .codegen import PyCodegen from .exc import unimplemented from .source import GlobalSource, LocalSource, Source -from .utils import nn_module_new, object_new +from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( is_side_effect_safe, MutableLocalBase, MutableLocalSource, VariableTracker, ) +from .variables.user_defined import FrozenDataClassVariable class MutableSideEffects(MutableLocalBase): @@ -78,6 +81,7 @@ class SideEffects: def __init__( self, + output_graph, id_to_variable=None, store_attr_mutations=None, keepalive=None, @@ -85,6 +89,7 @@ def __init__( tensor_hooks=None, ): super().__init__() + self.output_graph_weakref = weakref.ref(output_graph) self.id_to_variable = id_to_variable or {} self.store_attr_mutations = store_attr_mutations or {} self.keepalive = keepalive or [] @@ -129,6 +134,7 @@ def diff(self, other: "SideEffects") -> Optional[str]: def clone(self): """Create a shallow copy""" return self.__class__( + output_graph=self.output_graph_weakref(), id_to_variable=dict(self.id_to_variable), store_attr_mutations={ k: dict(v) for k, v in self.store_attr_mutations.items() @@ -144,6 +150,14 @@ def __contains__(self, item): def __getitem__(self, item): return self.id_to_variable[id(item)] + def should_allow_side_effects_under_checkpoint(self): + output_graph = self.output_graph_weakref() + return ( + output_graph + and output_graph.current_tx.output.current_tracer.under_activation_checkpoint + and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint + ) + def check_allowed_side_effect(self, item): from torch._dynamo.variables.misc import AutogradFunctionContextVariable @@ -151,6 +165,8 @@ def check_allowed_side_effect(self, item): # These are benign. if isinstance(item, AutogradFunctionContextVariable): return True + if self.should_allow_side_effects_under_checkpoint(): + return True if not is_side_effect_safe(item.mutable_local): unimplemented( "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)" @@ -285,6 +301,8 @@ def track_object_new_from_user_defined_class( variable_cls = variables.UnspecializedNNModuleVariable elif issubclass(user_cls, MutableMapping): variable_cls = variables.MutableMappingVariable + elif is_frozen_dataclass(user_cls): + variable_cls = FrozenDataClassVariable else: variable_cls = variables.UserDefinedObjectVariable @@ -590,11 +608,22 @@ def codegen_update_mutated(self, cg: PyCodegen): elif isinstance( var, variables.torch_function.TorchFunctionModeStackVariable ): + # Needed in the finally block for stack restoration + cg.add_push_null( + lambda: cg.load_import_from( + utils.__name__, "get_torch_function_mode_stack" + ) + ) + cg.call_function(0, False) + name = variables.torch_function.get_prev_stack_var_name() + cg.code_options["co_varnames"] += (name,) + cg.append_output(create_instruction("STORE_FAST", argval=name)) cg.add_push_null( lambda: cg.load_import_from( utils.__name__, "set_torch_function_mode_stack" ) ) + cg.foreach(var.symbolic_stack) cg.append_output( create_instruction("BUILD_LIST", arg=len(var.symbolic_stack)) @@ -696,3 +725,14 @@ def is_empty(self): def clear(self): self.keepalive.clear() self.id_to_variable.clear() + + +@contextlib.contextmanager +def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): # type: ignore[name-defined] # noqa: F821 + assert tx.output.current_tracer.under_activation_checkpoint + orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint + try: + tx.output.current_tracer.allow_side_effects_under_checkpoint = True + yield + finally: + tx.output.current_tracer.allow_side_effects_under_checkpoint = orig_val diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 7a3ad7a9b05ca8..9febc69f42cda2 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -25,6 +25,8 @@ GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE, GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } # represents nn.Modules tracked with UnspecializedNNModuleVariable @@ -39,6 +41,8 @@ # Just to ensure that guard_source() works GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } # represents nn.Modules tracked with UnspecializedBuiltinNNModuleVariable @@ -52,6 +56,8 @@ # Just to ensure that guard_source() works GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE, + GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, + GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, } _GUARD_SOURCE_FSDP_MODULE = { @@ -185,6 +191,11 @@ def name(self): return f"{self.base.name()}()" +@dataclasses.dataclass(frozen=True) +class CallFunctionNoArgsSource(WeakRefCallSource): + pass + + @dataclasses.dataclass(frozen=True) class AttrSource(ChainedSource): member: str @@ -261,7 +272,7 @@ def guard_source(self): def name(self): return f"" - def make_guard(self): + def make_guard(self, fn): raise NotImplementedError def is_ephemeral(self): @@ -382,6 +393,17 @@ def name(self): return f"{self.base.name()}._type().qualified_name()" +class AttrProxySource(ChainedSource): + def reconstruct(self, codegen): + self.base.reconstruct(codegen) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"{self.base.name()}.get_base()" + + @dataclasses.dataclass(frozen=True) class DefaultsSource(ChainedSource): idx_key: Union[int, str] @@ -597,7 +619,7 @@ class TorchFunctionModeStackSource(Source): ind: int def name(self): - return "" + return f"___get_torch_function_mode_stack_at({self._get_index()})" def _get_index(self): from .variables.torch_function import TorchFunctionModeStackVariable diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 76660b417af882..4a83b455a1252b 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -19,20 +19,7 @@ import types import typing import weakref -from typing import ( - Any, - Callable, - cast, - Deque, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from unittest.mock import patch import torch @@ -72,14 +59,12 @@ GlobalWeakRefSource, LocalSource, Source, - TorchFunctionModeStackSource, ) from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( counters, get_fake_value, get_instruction_source_311, - get_torch_function_mode_stack, graph_break_dup_warning_checker, istype, LazyString, @@ -120,11 +105,10 @@ ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable - - -if TYPE_CHECKING: - from .variables.torch_function import TorchFunctionModeVariable - +from .variables.torch_function import ( + SymbolicTorchFunctionState, + TorchFunctionModeVariable, +) from .variables.user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, @@ -283,9 +267,12 @@ def resume_fn(self): else: return ReenterWith(self.stack_index) - def exit(self, tx): + def exit(self, tx, is_graph_break): assert self.with_context is not None - return self.with_context.exit(tx) + if ( + is_graph_break and self.with_context.exit_on_graph_break() + ) or not is_graph_break: + return self.with_context.exit(tx) class ReturnValueOp(Exception): @@ -651,8 +638,17 @@ def handle_graph_break( cleanup: List[Instruction] = [] # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: + # Don't exit any modes we have entered, + # output bytecode will mutate the tf mode stack accordingly + if isinstance(b.with_context, TorchFunctionModeVariable): + cg.extend_output( + b.resume_fn().try_except_torch_function_mode( + cg.code_options, cleanup + ) + ) + continue assert b.with_context is not None - assert isinstance(b.with_context, ContextWrappingVariable) + assert isinstance(b.with_context, (ContextWrappingVariable)) b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) @@ -728,7 +724,7 @@ class InstructionTranslatorBase( output: OutputGraph symbolic_locals: Dict[str, VariableTracker] symbolic_globals: Dict[str, VariableTracker] - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"] + symbolic_torch_function_state: SymbolicTorchFunctionState stack: List[VariableTracker] instruction_pointer: Optional[int] current_instruction: Instruction @@ -1376,9 +1372,8 @@ def _raise_exception_variable(self, inst): # 2) when user raises exception instance if isinstance(val, variables.ExceptionVariable): - if val.exc_type is StopIteration: - # StopIteration is used to find the end of iteration while tracing __next__ - raise exc.ObservedUserStopIteration(f"raised exception {val}") + if observed_exception_type := exc.observed_exception_map.get(val.exc_type): + raise observed_exception_type(f"raised exception {val}") raise exc.ObservedException(f"raised exception {val}") unimplemented(f"raise {exc}") @@ -1664,8 +1659,8 @@ def CALL_FUNCTION_EX(self, inst): if not isinstance( argsvars, BaseListVariable - ) and argsvars.has_unpack_var_sequence(self): - argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) + ) and argsvars.has_force_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) # Unpack for cases like fn(**obj) where obj is a map if isinstance(kwargsvars, UserDefinedObjectVariable): @@ -1834,7 +1829,7 @@ def BUILD_LIST_UNPACK(self, inst, cls=ListVariable): items = [] for seq in seqs: try: - items.extend(seq.unpack_var_sequence(self)) + items.extend(seq.force_unpack_var_sequence(self)) except NotImplementedError: unimplemented(f"BUILD_LIST_UNPACK {seq}") self.push(cls(items, mutable_local=MutableLocal())) @@ -1872,7 +1867,7 @@ def BUILD_CONST_KEY_MAP(self, inst): assert isinstance(keys, TupleVariable) assert keys.is_python_constant() - keys = keys.unpack_var_sequence(self) + keys = keys.force_unpack_var_sequence(self) assert len(keys) == len(values) self.push( @@ -1962,8 +1957,8 @@ def UNPACK_SEQUENCE(self, inst): # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] - elif seq.has_unpack_var_sequence(self): - val = seq.unpack_var_sequence(self) + elif seq.has_force_unpack_var_sequence(self): + val = seq.force_unpack_var_sequence(self) else: unimplemented(f"UNPACK_SEQUENCE {seq}") if len(val) != inst.argval: @@ -1976,8 +1971,8 @@ def UNPACK_EX(self, inst): prefix = inst.argval & 0xFF # low byte suffix = inst.argval >> 8 # high byte seq = self.pop() - if seq.has_unpack_var_sequence(self): - vals = list(seq.unpack_var_sequence(self)) + if seq.has_force_unpack_var_sequence(self): + vals = list(seq.force_unpack_var_sequence(self)) assert len(vals) >= prefix + suffix vals_prefix = vals[:prefix] vals_list = vals[prefix : len(vals) - suffix] @@ -2306,7 +2301,10 @@ def setup_or_before_with(self, inst): ): unimplemented(f"{inst.opname} {ctx}") - if isinstance(ctx, GenericContextWrappingVariable): + if ( + isinstance(ctx, GenericContextWrappingVariable) + and not ctx.supports_graph_breaks() + ): self.generic_context_manager_depth += 1 # Need this redundant check for mypy @@ -2401,7 +2399,7 @@ def CALL_INTRINSIC_1(self, inst): self.UNARY_POSITIVE(inst) elif inst.argval == 6: # INTRINSIC_LIST_TO_TUPLE - self.push(TupleVariable(self.pop().unpack_var_sequence(self))) + self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) else: unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") @@ -2549,7 +2547,7 @@ def __init__( code_options: Dict[str, Any], symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -2564,7 +2562,7 @@ def __init__( self.output = output self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals - self.symbolic_torch_function_mode_stack = symbolic_torch_function_mode_stack + self.symbolic_torch_function_state = symbolic_torch_function_state self.stack = [] # stack of variable names for tracking 3.13 closures self.name_stack: list[Any] = [] @@ -2653,6 +2651,7 @@ def __init__( f_locals, f_globals, f_builtins, + torch_function_mode_stack, code_options, compiler_fn, one_graph, @@ -2678,6 +2677,7 @@ def __init__( local_scope=f_locals, global_scope=f_globals, f_code=f_code, + torch_function_mode_stack=torch_function_mode_stack, ), instructions=instructions, f_locals=f_locals, @@ -2687,7 +2687,7 @@ def __init__( symbolic_locals={}, # set below # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, - symbolic_torch_function_mode_stack=collections.deque(), + symbolic_torch_function_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -2722,7 +2722,9 @@ def __init__( if k in f_locals } - self._init_torch_function_mode_stack() + self.symbolic_torch_function_state = SymbolicTorchFunctionState( + torch_function_mode_stack + ) self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = [] if export: @@ -2763,29 +2765,6 @@ def _throw_if_in_functorch(self): ) unimplemented(msg) - def _init_torch_function_mode_stack(self): - from .variables.torch_function import TorchFunctionModeStackVariable - - TorchFunctionModeStackVariable.reset() - - self.symbolic_torch_function_mode_stack: Deque[ - TorchFunctionModeVariable - ] = collections.deque() - # We want to retrieve all modes to properly reconstruct the stack if needed - py_stack = get_torch_function_mode_stack(filter_ignored=False) - - if py_stack: - has_device_context = isinstance( - py_stack[0], torch.utils._device.DeviceContext - ) - - for i, val in enumerate(py_stack): - self.symbolic_torch_function_mode_stack.append( - variables.LazyVariableTracker.create( - val, source=TorchFunctionModeStackSource(i) - ) - ) - def get_example_value(self, source: Source): if isinstance(source, LocalSource): return self.f_locals[source.local_name] @@ -3117,7 +3096,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3127,7 +3106,7 @@ def get_trace_call_log_str(): code, sub_locals, parent.symbolic_globals, - parent.symbolic_torch_function_mode_stack, + parent.symbolic_torch_function_state, closure_cells, func, ) @@ -3180,7 +3159,7 @@ def __init__( code: types.CodeType, symbolic_locals: Dict[str, VariableTracker], symbolic_globals: Dict[str, VariableTracker], - symbolic_torch_function_mode_stack: Deque["TorchFunctionModeVariable"], + symbolic_torch_function_state: SymbolicTorchFunctionState, closure_cells: Dict[str, VariableTracker], funcvar: BaseUserFunctionVariable, ) -> None: @@ -3197,7 +3176,7 @@ def __init__( f_builtins=f_builtins, symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, - symbolic_torch_function_mode_stack=symbolic_torch_function_mode_stack, + symbolic_torch_function_state=symbolic_torch_function_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, @@ -3219,7 +3198,7 @@ def fake_mode(self): def run_ctx_mgr(self): return TracingContext.current_frame(self.parent.frame_summary()) - def STORE_DEREF(self, inst): + def STORE_DEREF(self, inst): # type: ignore[override] if inst.argval in self.closure_cells: cell = self.closure_cells[inst.argval] val = self.pop() diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index 4736c75785ccb7..b05542d578f43b 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -15,6 +15,7 @@ import torch import torch._dynamo import torch._dynamo.test_case +from torch._dynamo.trace_rules import _as_posix_path from torch.utils._traceback import report_compile_source_on_error @@ -107,8 +108,9 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): log = logging.getLogger("torch._dynamo") log.addHandler(log_handler) try: - prev_cwd = os.getcwd() + prev_cwd = _as_posix_path(os.getcwd()) if cwd is not None: + cwd = _as_posix_path(cwd) os.chdir(cwd) with patch("sys.argv", args), report_compile_source_on_error(): exec(code, {"__name__": "__main__", "__compile_source__": code}) @@ -135,6 +137,8 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): stderr.getvalue().encode("utf-8"), ) else: + if cwd is not None: + cwd = _as_posix_path(cwd) return subprocess.run(args, capture_output=True, cwd=cwd, check=False) # Run `code` in a separate python process. @@ -157,7 +161,7 @@ def _run_test_code(self, code, *, isolate): # Runs the minifier launcher script in `repro_dir` def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): self.assertIsNotNone(repro_dir) - launch_file = os.path.join(repro_dir, "minifier_launcher.py") + launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) with open(launch_file) as f: launch_code = f.read() self.assertTrue(os.path.exists(launch_file)) @@ -176,7 +180,7 @@ def _run_minifier_launcher(self, repro_dir, isolate, *, minifier_args=()): # Runs the repro script in `repro_dir` def _run_repro(self, repro_dir, *, isolate=True): self.assertIsNotNone(repro_dir) - repro_file = os.path.join(repro_dir, "repro.py") + repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) with open(repro_file) as f: repro_code = f.read() self.assertTrue(os.path.exists(repro_file)) @@ -196,11 +200,11 @@ def _gen_test_code(self, run_code, repro_after, repro_level): return f"""\ import torch import torch._dynamo -{torch._dynamo.config.codegen_config()} -{torch._inductor.config.codegen_config()} +{_as_posix_path(torch._dynamo.config.codegen_config())} +{_as_posix_path(torch._inductor.config.codegen_config())} torch._dynamo.config.repro_after = "{repro_after}" torch._dynamo.config.repro_level = {repro_level} -torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}" {run_code} """ diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 6c5f47067547e9..704a3889707233 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -163,6 +163,7 @@ def insert_nops(instructions, code_options): local_scope=locals(), global_scope=globals(), f_code=frame.f_code, + torch_function_mode_stack=[], ) return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) @@ -372,6 +373,13 @@ def skipIfPy312(fn): return fn +def requiresPy310(fn): + if sys.version_info >= (3, 10): + return fn + else: + unittest.skip(fn) + + # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py # and test/dynamo/test_dynamic_shapes.py def expectedFailureDynamic(fn): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index ccb29ab13b1479..44e1d662efe4af 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -12,7 +12,6 @@ import functools import importlib import inspect -import itertools import linecache import logging import multiprocessing @@ -34,7 +33,7 @@ import weakref from collections import defaultdict from pathlib import Path -from typing import Any, Callable, cast, Dict, List, Optional, Set, Union +from typing import Any, Callable, cast, Dict, List, Optional, Set, Type, Union import torch import torch._inductor.test_operators @@ -145,7 +144,7 @@ "torch.distributed.is_initialized": TorchInGraphFunctionVariable, "torch.distributed.get_rank": TorchInGraphFunctionVariable, "torch.distributed.get_world_size": TorchInGraphFunctionVariable, - "torch.distributed._tensor.api.DTensor#from_local": TorchInGraphFunctionVariable, + "torch.distributed.tensor._api.DTensor#from_local": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_size_by_name": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._resolve_group_name_by_ranks_and_tag": TorchInGraphFunctionVariable, "torch.distributed.distributed_c10d._get_group_tag": TorchInGraphFunctionVariable, @@ -179,8 +178,9 @@ "torch.nn.Parameter": TorchInGraphFunctionVariable, "torch.nn.Buffer": TorchInGraphFunctionVariable, "torch._nested_tensor_from_mask": SkipFunctionVariable, - "torch._nested_from_padded": SkipFunctionVariable, + "torch.nested._internal.nested_tensor.nested_from_padded": TorchInGraphFunctionVariable, "torch.nested.nested_tensor_from_jagged": UserFunctionVariable, + "torch.nested.nested_tensor_from_padded": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, @@ -192,7 +192,7 @@ "torch.Tensor#_make_wrapper_subclass": SkipFunctionVariable, "torch.Tensor#__init__": SkipFunctionVariable, "torch.cuda.set_device": SkipFunctionVariable, - "torch.cuda.current_device": SkipFunctionVariable, + "torch.cuda.current_device": TorchInGraphFunctionVariable, "torch._C.autocast_decrement_nesting": SkipFunctionVariable, "torch._C.autocast_increment_nesting": SkipFunctionVariable, "torch.autograd.grad": SkipFunctionVariable, @@ -304,6 +304,7 @@ "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, + "torch.set_default_device": UserFunctionVariable, "torch.sparse_bsc_tensor": SkipFunctionVariable, "torch.sparse_bsr_tensor": SkipFunctionVariable, "torch.sparse_csc_tensor": SkipFunctionVariable, @@ -414,10 +415,11 @@ "torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata", "torch._C._construct_storage_from_data_pointer", "torch._C._conv_determine_backend_memory_format", - "torch._C._cpu._is_cpu_support_avx2", - "torch._C._cpu._is_cpu_support_avx512", - "torch._C._cpu._is_cpu_support_avx512_vnni", - "torch._C._cpu._is_cpu_support_amx_tile", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._init_amx", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", @@ -1525,6 +1527,7 @@ "torch._neg_view_copy", "torch._neg_view", "torch._nested_from_padded_and_nested_example", + "torch._nested_from_padded_tensor", "torch._nested_tensor_from_mask_left_aligned", "torch._nested_tensor_from_tensor_list", "torch._nested_tensor_softmax_with_shape", @@ -2432,10 +2435,11 @@ "torch.chain_matmul", "torch.compile", "torch.compiled_with_cxx11_abi", - "torch.cpu._is_cpu_support_avx2", - "torch.cpu._is_cpu_support_avx512", - "torch.cpu._is_cpu_support_avx512_vnni", - "torch.cpu._is_cpu_support_amx_tile", + "torch._C._cpu._is_avx2_supported", + "torch._C._cpu._is_avx512_supported", + "torch._C._cpu._is_avx512_vnni_supported", + "torch._C._cpu._is_avx512_bf16_supported", + "torch._C._cpu._is_amx_tile_supported", "torch.cpu._init_amx", "torch.cpu.current_device", "torch.cpu.current_stream", @@ -2794,7 +2798,6 @@ "torch.random.initial_seed", "torch.random.seed", "torch.return_types.pytree_register_structseq", - "torch.set_default_device", "torch.set_default_dtype", "torch.set_default_tensor_type", "torch.set_deterministic_debug_mode", @@ -2847,8 +2850,8 @@ @functools.lru_cache(None) -def get_torch_obj_rule_map(): - d: Dict[Any, VariableTracker] = {} +def get_torch_obj_rule_map() -> Dict[Any, Type["VariableTracker"]]: + d: Dict[Any, Type[VariableTracker]] = {} for m in torch_name_rule_map: for k, v in m.items(): # type: ignore[attr-defined] if ".py#" not in k: @@ -2993,9 +2996,6 @@ def _builtin_function_ids() -> Dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update( - {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)} - ) rv.update( { id(cast): "typing.cast", @@ -3005,6 +3005,12 @@ def _builtin_function_ids() -> Dict[int, str]: return rv +@FunctionIdSet +def _polyfilled_function_ids() -> Set[int]: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return set() + + @FunctionIdSet def _numpy_function_ids() -> Dict[int, str]: rv = {} @@ -3080,6 +3086,11 @@ def is_builtin_constant(obj) -> bool: return id(obj) in _builtin_constant_ids +def is_polyfilled_callable(obj) -> bool: + # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids + return id(obj) in _polyfilled_function_ids + + def is_numpy(obj) -> bool: if np is None: return False @@ -3189,6 +3200,7 @@ def _module_dir(m: types.ModuleType): "torch._higher_order_ops.cond", "torch._higher_order_ops.while_loop", "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.scan", "torch.nn.attention.flex_attention", "torch.ao.quantization.pt2e.export_utils", "torch.ao.quantization.pt2e.qat_utils", @@ -3201,8 +3213,8 @@ def _module_dir(m: types.ModuleType): if torch.distributed.is_available(): LEGACY_MOD_INLINELIST |= { - "torch.distributed._tensor.api", - "torch.distributed._tensor.device_mesh", + "torch.distributed.tensor._api", + "torch.distributed.tensor.device_mesh", "torch.distributed.device_mesh", "torch.distributed.algorithms._checkpoint.checkpoint_wrapper", "torch.distributed.tensor.parallel._data_parallel_utils", @@ -3211,26 +3223,33 @@ def _module_dir(m: types.ModuleType): # we have to add replicate to LEGACY_MOD_INLINELIST to ensure # the forward_hook won't be ignored. "torch.distributed._composable.replicate", - "torch.distributed._composable.fsdp", } + if not torch._dynamo.config.skip_fsdp_hooks: + LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp") # Force inline functions under these modules, even they are in *_SKIPLIST. # We are using python module name instead of file or directory object to avoid circular dependency. # Please keep this sorted alphabetically. -MOD_INLINELIST = { - "torch.utils._python_dispatch", - "torch._refs", - "torch._prims", +MOD_INLINELIST = [ "torch._decomp", "torch._dynamo._trace_wrapped_higher_order_op", "torch._dynamo.comptime", "torch._dynamo.polyfills", - "torch._functorch.vmap", "torch._functorch.autograd_function", - "torch._library.custom_ops", "torch._functorch.eager_transforms", + "torch._functorch.functional_call", + "torch._functorch.vmap", + "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.scan", + "torch._higher_order_ops.strict_mode", + "torch._higher_order_ops.while_loop", "torch._inductor.test_operators", + "torch._library.autograd", + "torch._library.custom_ops", + "torch._prims", + "torch._refs", + "torch._tensor", "torch.amp.autocast_mode", "torch.ao.nn", "torch.autograd.function", @@ -3239,31 +3258,30 @@ def _module_dir(m: types.ModuleType): "torch.distributions", "torch.export._tree_utils", "torch.fx._pytree", + "torch.fx._symbolic_trace", + "torch.fx.experimental.proxy_tensor", "torch.fx.passes.shape_prop", "torch.nn", "torch.overrides", "torch.random", "torch.sparse", "torch.testing", - "torch.testing._internal.hypothesis_utils", "torch.utils._content_store", "torch.utils._contextlib", + "torch.utils._device", "torch.utils._foreach_utils", + "torch.utils._python_dispatch", "torch.utils._pytree", "torch.utils.hooks", - "torch._tensor", - "torch._higher_order_ops.strict_mode", - "torch._higher_order_ops.while_loop", - "torch._higher_order_ops.associative_scan", - "torch._functorch.functional_call", -} +] +assert sorted(set(MOD_INLINELIST)) == MOD_INLINELIST +MOD_INLINELIST = set(MOD_INLINELIST) if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") - MOD_INLINELIST.add("torch.distributed._functional_collectives") - MOD_INLINELIST.add("torch.distributed._composable.replicate") - MOD_INLINELIST.add("torch.distributed._composable.fsdp") + if not torch._dynamo.config.skip_fsdp_hooks: + MOD_INLINELIST.add("torch.distributed._composable.fsdp") @functools.lru_cache(None) @@ -3470,9 +3488,7 @@ def check_verbose(obj, is_inlined_call=False): # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: Set[str] = set() - rule = torch._dynamo.trace_rules.lookup_inner( - fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons - ) + rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) if issubclass(rule, (UserFunctionVariable, PolyfilledFunctionVariable)): return SkipResult( False, @@ -3527,6 +3543,8 @@ def lookup_callable(obj): return SkipFunctionVariable if is_callable_allowed(obj): return TorchInGraphFunctionVariable + if is_polyfilled_callable(obj): + return PolyfilledFunctionVariable if is_builtin_callable(obj): return BuiltinVariable return None @@ -3587,7 +3605,9 @@ def lookup_inner( if reasons is not None: reasons.add("func name is patched_init") return SkipFunctionVariable - elif name == "__torch_function__": + elif name == "__torch_function__" or ( + obj and obj.__name__ == "__torch_function__" + ): if reasons is not None: reasons.add("func name is __torch_function__") return UserFunctionVariable diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c8b2ac7a953af5..7d34671b11a3c0 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -21,7 +21,6 @@ import os import re import sys -import textwrap import threading import time import types @@ -30,6 +29,7 @@ import warnings import weakref from contextlib import contextmanager +from dataclasses import is_dataclass from functools import lru_cache from types import MethodWrapperType from typing import ( @@ -63,7 +63,7 @@ import torch.utils._pytree as pytree from torch import fx from torch._C import ( - _get_function_stack_at, + _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, _push_on_torch_function_stack, @@ -219,6 +219,21 @@ def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: frame_phase_timing[key][phase_name] += time_spent +# Use frame_phase_timing to record remote_cache_time_saved +# This follows the same principles of key as the other frame phase timings, +# but is incremented by FxGraphCache (and later AOTAutogradCache) directly +def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None: + key = None + if is_backward: + # Use compile id as the frame key for backwards compilation + key = str(torch._guards.CompileContext.current_compile_id()) + else: + key = str(curr_frame) + # Convert to seconds (as a float) + time_saved = time_saved_ns / 1e9 + _add_time_spent(key, "remote_cache_time_saved", time_saved) + + def get_cache_stats() -> Dict[str, Any]: """Get a bunch of metadata about cache hits and misses to use in chromium events""" cache_stats = { @@ -331,15 +346,24 @@ def dynamo_timed( code_gen_time = frame_phase_timing[compile_id].get( "code_gen", None ) + remote_cache_time_saved = frame_phase_timing[ + compile_id + ].get("remote_cache_time_saved", None) else: inductor_compile_time = None code_gen_time = None + remote_cache_time_saved = None + structured_logging_overhead_s = ( + torch._logging.get_structured_logging_overhead() + ) metrics = BwdCompilationMetrics( compile_id, inductor_compile_time, code_gen_time, fail_type, fail_reason, + remote_cache_time_saved, + structured_logging_overhead_s, ) record_compilation_metrics(metrics) @@ -778,6 +802,8 @@ class CompilationMetrics: # a compiled frame has_guarded_code: bool possibly_missed_reinplacing_opportunities: Optional[int] + remote_cache_time_saved_s: Optional[float] + structured_logging_overhead_s: Optional[float] @dataclasses.dataclass @@ -787,6 +813,8 @@ class BwdCompilationMetrics: code_gen_time_s: Optional[float] fail_type: Optional[str] fail_reason: Optional[str] + remote_cache_time_saved_s: Optional[float] + structured_logging_overhead_s: Optional[float] DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -812,6 +840,11 @@ def record_compilation_metrics( k: list(v) if isinstance(v, set) else v for k, v in dataclasses.asdict(compilation_metrics).items() }, + # NB: Because compilation metrics *includes* the logging overhead time, + # we can't both *measure* the logging overhead of compilation metrics + # without making it inconsistent with compilation metrics itself, so + # we ignore the (hopefully small) time spent logging compilation metrics + record_logging_overhead=False, ) if config.log_compilation_metrics: log_compilation_event(compilation_metrics) @@ -1607,8 +1640,12 @@ def same( """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref - if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): - assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" + if isinstance( + ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) + ): + assert isinstance( + res, (list, tuple, collections.deque) + ), f"type mismatch {type(ref)} {type(res)}" if len(ref) != len(res): log_error("Length mismatch") return False @@ -1727,6 +1764,26 @@ def to_tensor(t): # Check error from fp64 version if fp64_ref.dtype == torch.float64: + # Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance) + # while the ref contains NaN. In this case, RMSE should not match any ways. + # But res is 'BETTER' than ref so we count it pass. + # + # This happens for Super_SloMo when loop ordering after fusion is enabled: + # https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab + loose_tol = 1e-2 * 4 + if ( + not fp64_ref.isnan().any() + and not res.isnan().any() + and ref.isnan().any() + and torch.allclose( + fp64_ref.to(dtype=res.dtype), + res, + atol=loose_tol, + rtol=loose_tol, + equal_nan=equal_nan, + ) + ): + return True ref_error = rmse(fp64_ref, ref).item() # ref unable to produce this with stable numerics in this precision, ignore if math.isnan(ref_error): @@ -1741,7 +1798,9 @@ def to_tensor(t): # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%. # Thus, it's possible that the correctness check failures for these models are # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms. - multiplier = 3.0 if res.dtype == torch.bfloat16 else 2.0 + multiplier = ( + 3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0 + ) if use_larger_multiplier_for_smaller_tensor and ( fp64_ref.numel() <= 10 and tol >= 4 * 1e-2 @@ -1889,104 +1948,6 @@ def disable_cache_limit(): seen_code_map = ExactWeakKeyDictionary() -class CompileProfiler: - """Utility for profiling how and what dynamo would compile. - - Can be used for - * diagnosing recompilation issues - * determining an appropriate compile cache limit - * (TODO)confirming which functions got compiled/skipped - """ - - def __init__(self): - self.frame_count = 0 - self.op_count = 0 - self.backend_ctx_ctor = disable_cache_limit - - def __call__(self, gm: torch.fx.GraphModule, example_inputs): - self.frame_count += 1 - for node in gm.graph.nodes: - if "call" in node.op: - self.op_count += 1 - return gm.forward - - # no-op __enter__ and __exit__ to preserve BC - def __enter__(self): - return self - - def __exit__(self, typ, val, traceback): - pass - - def get_metrics(self): - return {"guard_failures": guard_failures} - - def report(self): - metrics = self.get_metrics() - gf = metrics["guard_failures"] - - def num_recompiles(code): - return len(gf[code]) - - def recompile_reasons(code): - return "\n".join([str(x) for x in gf[code]]) - - summarized_gf = [ - [format_func_info(code), num_recompiles(code), recompile_reasons(code)] - for code in gf - ] - - def graph_break_report(): - if "graph_break" in counters: - graph_breaks = counters["graph_break"] - return tabulate( - [[msg, graph_breaks[msg]] for msg in graph_breaks], - headers=["Graph Break Reason", "Count"], - ) - - def recompilation_report(): - if len(gf): - max_recompiles = max(num_recompiles(code) for code in gf) - recomp_table = tabulate( - summarized_gf, - headers=["Function", "Recompiles", "Recompile Reasons"], - ) - return recomp_table + textwrap.dedent( - f""" - - Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited. - """ - ) - - report = textwrap.dedent( - """ - Torchdynamo Profiler Report - =========================== - - Graph Breaks - ------------ - Graph breaks happen when torchdynamo encounters code it can't safely trace. - If you want to find out why breaks are happening, check below for each break reason - You may gain additional insight by passing `fullgraph=True` to torch.compile, - to stop at the first break. - - """ - ) - report += graph_break_report() or "No graph breaks detected." - report += textwrap.dedent( - """ - - Recompilation - ------------- - These subgraphs were recompiled more than once due to guard failures - Guard failures indicate some condition assumed to be static by the tracer changed, - making it unsafe to reuse the compiled program. - - """ - ) - report += recompilation_report() or "No recompilation detected.\n" - return report - - # return same dir unless user changes config between calls @functools.lru_cache(None) def _get_debug_dir(root_dir): @@ -2142,10 +2103,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): ): raise UserError( # noqa: B904 UserErrorType.CONSTRAINT_VIOLATION, - "Tried to use data-dependent value in the subsequent computation. " - "This can happen when we encounter unbounded dynamic value that is unknown during tracing time. " - "You will need to explicitly give hint to the compiler. Please take a look at " - f"torch._check OR torch._check_is_size APIs. {cause}", + str(cause), case_name="constrain_as_size_example", ) elif isinstance(cause, ValueRangeError): @@ -2312,9 +2270,13 @@ def import_submodule(mod: types.ModuleType): def object_has_getattribute(value: Any): + return class_has_getattribute(type(value)) + + +def class_has_getattribute(cls: type): try: if isinstance( - inspect.getattr_static(type(value), "__getattribute__"), + inspect.getattr_static(cls, "__getattribute__"), types.FunctionType, ): return True @@ -2960,6 +2922,18 @@ def to_fake_tensor(t, fake_mode): ) +# NB: this works for both classes and instances +def is_frozen_dataclass(value): + return ( + not object_has_getattribute(value) + and not class_has_getattribute(value) + and is_dataclass(value) + and hasattr(value, "__dataclass_params__") + and hasattr(value.__dataclass_params__, "frozen") + and value.__dataclass_params__.frozen + ) + + def get_first_attr(obj, *attrs): """ Return the first available attribute or throw an exception if none is present. @@ -3124,14 +3098,10 @@ def is_parameter_freezing(): return torch._inductor.config.freezing and not torch.is_grad_enabled() -def get_torch_function_mode_stack(filter_ignored=True): - from .variables.torch_function import IGNORED_MODES - - stack = [_get_function_stack_at(i) for i in range(_len_torch_function_stack())] - if filter_ignored: - stack = [mode for mode in stack if type(mode) not in IGNORED_MODES] - - return stack +def get_torch_function_mode_stack(): + return [ + get_torch_function_mode_stack_at(i) for i in range(_len_torch_function_stack()) + ] def get_torch_function_mode_stack_at(ind): @@ -3147,6 +3117,11 @@ def set_torch_function_mode_stack(stack): _push_on_torch_function_stack(mode) +def clear_torch_function_mode_stack(): + for i in range(_len_torch_function_stack()): + _pop_torch_function_stack() + + def verify_guard_fn_signature(value): fn = value.__metadata_guard__ sig = inspect.signature(fn) @@ -3203,3 +3178,41 @@ def get_user_object_from_id(obj_id): def store_user_object_weakref(obj): obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +class CompileTimeInstructionCounter: + _counter: int = 0 + _id: int = -1 + _depth = 0 + + @classmethod + def start(cls) -> None: + cls._depth = cls._depth + 1 + if cls._depth == 1: + cls._id = _instruction_counter.start() + + @classmethod + def end(cls) -> None: + cls._depth = cls._depth - 1 + if cls._depth == 0: + cls._counter += _instruction_counter.end(cls._id) + cls._id = -1 + + @classmethod + def clear(cls) -> None: + cls._counter = 0 + + @classmethod + def value(cls) -> int: + return cls._counter + + @classmethod + @contextmanager + def record(cls): + try: + if config.record_compile_time_instruction_count: + cls.start() + yield + finally: + if config.record_compile_time_instruction_count: + cls.end() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index a6a46bb3932004..5a8522e68c4c06 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -14,6 +14,7 @@ GradModeVariable, InferenceModeVariable, JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, SetFwdGradEnabledContextManager, StreamContextVariable, StreamVariable, @@ -24,6 +25,7 @@ ConstDictVariable, CustomizedDictVariable, DefaultDictVariable, + FrozensetVariable, SetVariable, ) from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable @@ -45,7 +47,9 @@ CycleIteratorVariable, IteratorVariable, ItertoolsVariable, + MapVariable, RepeatIteratorVariable, + ZipVariable, ) from .lazy import LazyVariableTracker from .lists import ( diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index e7a3b7320e5c62..723c5a90c66ac6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -207,7 +207,10 @@ def python_type(self): Raises: NotImplementedError: If the method is not implemented in a subclass. """ - raise NotImplementedError(f"{self} has no type") + try: + return type(self.as_python_constant()) + except NotImplementedError: + raise NotImplementedError(f"{self} has no type") from None def as_python_constant(self): """For constants""" @@ -286,6 +289,15 @@ def can_reconstruct(self, tx): def unpack_var_sequence(self, tx) -> List["VariableTracker"]: raise NotImplementedError + def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]: + # like unpack_var_sequence, but should only be used when it is + # safe to eagerly (vs. lazily) unpack this variable. + # e.g. map(f, x) is normally evaluated lazily but sometimes + # we want to force eager unpacking, e.g. when converting to a list. + # NOTE: this method is allowed to mutate the VariableTracker, so + # it should only be called once. + return self.unpack_var_sequence(tx) + def has_unpack_var_sequence(self, tx) -> bool: try: self.unpack_var_sequence(tx) @@ -293,6 +305,10 @@ def has_unpack_var_sequence(self, tx) -> bool: except NotImplementedError: return False + # NB: don't call force_unpack_var_sequence, especially if it mutates! + def has_force_unpack_var_sequence(self, tx) -> bool: + return self.has_unpack_var_sequence(tx) + def inspect_parameter_names(self) -> List[str]: unimplemented(f"inspect_parameter_names: {self}") diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index cfc563307f89ed..46e970cd85335f 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -15,8 +15,21 @@ import re import sys import types +import warnings import weakref -from typing import Any, List, MutableMapping, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + List, + MutableMapping, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) import torch from torch import SymInt @@ -25,7 +38,7 @@ from torch._ops import HigherOrderOperator from torch._streambase import _EventBase, _StreamBase from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode -from torch._subclasses.meta_utils import is_sparse_any +from torch._subclasses.meta_utils import is_sparse_any, safe_grad from torch._utils_internal import justknobs_check from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.symbolic_shapes import ( @@ -47,6 +60,7 @@ from ..guards import GuardBuilder, install_guard, make_dupe_guard from ..side_effects import SideEffects from ..source import ( + AttrProxySource, AttrSource, CallMethodItemSource, ConstantSource, @@ -81,6 +95,7 @@ get_fake_value, get_locals_to_steal, get_static_address_type, + is_frozen_dataclass, is_function_or_wrapper, is_lru_cache_wrapped_function, is_namedtuple, @@ -189,9 +204,11 @@ from .torch_function import ( build_torch_function_fn, TensorWithTFOverrideVariable, + torch_function_mode_stack_state_mgr, TorchFunctionModeVariable, ) from .user_defined import ( + FrozenDataClassVariable, KeyedJaggedTensorVariable, MutableMappingVariable, SourcelessGraphModuleVariable, @@ -220,6 +237,12 @@ DimList = List +def safe_has_grad(t): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return hasattr(t, "grad") + + class _missing: pass @@ -312,6 +335,17 @@ class FrameStateSizeEntry: stride: Optional[List[int]] +# All class-based iterators in itertools +# NOTE: use id() because some objects are not hashable, it will raise error during lookup +ITERTOOLS_TYPE_IDS: FrozenSet[int] = frozenset( + id(member) + for name, member in vars(itertools).items() + if not name.startswith("_") and inspect.isclass(member) +) +# Will be updated later in substitute_in_graph in torch/_dynamo/polyfills/itertools.py +ITERTOOLS_POLYFILLED_TYPE_IDS: Set[int] = set() + + class VariableBuilder: """Wrap a python value in a VariableTracker() instance""" @@ -455,7 +489,9 @@ def wrap_jit_function(self, value): @classmethod @functools.lru_cache(None) - def _id_dispatch(cls): + def _id_dispatch( + cls, + ) -> Dict[int, Callable[["VariableBuilder", Any], VariableTracker]]: from ..comptime import comptime entries = [ @@ -512,10 +548,6 @@ class Autotuner: if id_dispatch is not None: return id_dispatch(self, value) - # Note - There are some nested values where types mismatch! - # We want to get those out and wrap those. - value = inspect.getattr_static(value, "_torchdynamo_inline", value) - # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses @@ -794,7 +826,7 @@ def build_key_value(i, k, v): self.install_guards(GuardBuilder.ID_MATCH) stream_proxy = self.tx.output.create_proxy( "call_function", - torch.cuda.Stream, + type(value), (), { "stream_id": value.stream_id, @@ -812,6 +844,9 @@ def build_key_value(i, k, v): elif isinstance(value, (torch._C._SDPAParams)): self.install_guards(GuardBuilder.TYPE_MATCH) return SDPAParamsVariable.create(self.tx, value, self.source) + elif isinstance(value, torch._C._SDPBackend): + self.install_guards(GuardBuilder.ID_MATCH) + return ConstantVariable(value) elif isinstance(value, _EventBase): self.install_guards(GuardBuilder.ID_MATCH) torch._dynamo.utils.store_user_object_weakref(value) @@ -867,7 +902,10 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif istype(value, type) and value in itertools.__dict__.values(): + elif ( + id(value) in ITERTOOLS_TYPE_IDS + and id(value) not in ITERTOOLS_POLYFILLED_TYPE_IDS + ): self.install_guards(GuardBuilder.FUNCTION_MATCH) return ItertoolsVariable(value, source=self.source) elif isinstance(value, torch.SymBool): @@ -949,6 +987,13 @@ def build_key_value(i, k, v): elif is_lru_cache_wrapped_function(value): self.install_guards(GuardBuilder.TYPE_MATCH) return WrapperUserFunctionVariable(value, "__wrapped__", source=self.source) + elif is_function_or_wrapper(value) and inspect.getattr_static( + value, "_torchdynamo_inline", False + ): + self.install_guards(GuardBuilder.TYPE_MATCH) + return WrapperUserFunctionVariable( + value, "_torchdynamo_inline", source=self.source + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, @@ -1125,6 +1170,10 @@ def build_key_value(i, k, v): elif issubclass(type(value), MutableMapping): self.install_guards(GuardBuilder.TYPE_MATCH) return MutableMappingVariable(value, source=self.source) + elif is_frozen_dataclass(value): + self.install_guards(GuardBuilder.TYPE_MATCH) + result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) else: return self.wrap_user_defined(value) @@ -1149,8 +1198,22 @@ def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): if item is value: unimplemented("list elements are pointing to the list itself") + # Tuples are immutable objects, so we should mark its items static. This + # avoids wrapping of tuple items as symints. This helps for nn module + # attributes like conv2d strides, dilations. + if ( + istype(value, tuple) + and all(ConstantVariable.is_literal(item) for item in value) + and self.source.guard_source().is_unspecialized_nn_module() + ): + self.install_guards(GuardBuilder.CONSTANT_MATCH) + return TupleVariable([ConstantVariable.create(item) for item in value]) + output = [ - LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i)) + LazyVariableTracker.create( + item, + source=GetItemSource(self.get_source(), i), + ) for i, item in enumerate(value) ] @@ -1309,6 +1372,13 @@ def wrap_module(self, value: torch.nn.Module): return self.tx.output.side_effects.track_object_existing(value, result) elif mutation_guard.is_dynamic_nn_module(value, self.tx.export): # created dynamically, don't specialize on it + + # Note [Tracing a torch.compiled function] + # when make_fx tracing a compiled function, we need + if isinstance(value, torch.fx.experimental.proxy_tensor._AttrProxy): + value = value.get_base() + self.source = AttrProxySource(self.source) + self.install_guards(GuardBuilder.TYPE_MATCH) if torch._dynamo.config.inline_inbuilt_nn_modules: freezing = is_parameter_freezing() @@ -1324,7 +1394,9 @@ def wrap_module(self, value: torch.nn.Module): # this will get cleaned up once compile ends self.tx.output.nn_modules[self.name] = value - if value.__module__.startswith(("torch.nn.", "torch.ao.")): + if value.__module__.startswith(("torch.nn.", "torch.ao.")) or getattr( + value.__class__, "_dynamo_marked_static", False + ): result = UnspecializedBuiltinNNModuleVariable(value, source=self.source) else: result = UnspecializedNNModuleVariable(value, source=self.source) @@ -1408,7 +1480,8 @@ def wrap_tensor(self, value: torch.Tensor): or (source and source.guard_source().is_unspecialized_nn_module()) ) ): - self.mark_static_input(value, guard=False) + self.mark_static_input(value, guard=is_parameter_freezing()) + is_static_input = True make_graph_attribute = is_static_input and ( not config.inline_inbuilt_nn_modules or is_parameter_freezing() @@ -1511,6 +1584,18 @@ def wrap_tensor(self, value: torch.Tensor): # SPARSE_TENSOR_GUARDS for guards to work propertly. unimplemented("torch.compile does not support sparse Tensors") + if ( + safe_has_grad(value) + and safe_grad(value) is not None + and value.dtype != safe_grad(value).dtype + ): + unimplemented( + "Inconsistent dtype between tensor and its gradient. " + "This can happen in FSDP and crashes meta tensor creation. " + "This is potentially a workaround. Fixing it correctly " + "requires some design around FSDP + torch.compile." + ) + tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, @@ -1584,15 +1669,16 @@ def wrap_numpy_ndarray(self, value): # but warning is not the end of the world assert isinstance(value.base, np.nditer) - try: - tensor_value = _util._try_convert_to_tensor(value) - if readonly: - from torch._prims_common import clone_preserve_strides + with torch_function_mode_stack_state_mgr.temp_restore_stack(): + try: + tensor_value = _util._try_convert_to_tensor(value) + if readonly: + from torch._prims_common import clone_preserve_strides - tensor_value = clone_preserve_strides(tensor_value) - except NotImplementedError as e: - # failed to convert to tensor, graph break - unimplemented(str(e)) + tensor_value = clone_preserve_strides(tensor_value) + except NotImplementedError as e: + # failed to convert to tensor, graph break + unimplemented(str(e)) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here @@ -1675,6 +1761,15 @@ def update_frame_state(value): value, frame_state_entry.scalar, ) + if self.source.guard_source().is_unspecialized_nn_module(): + log.info( + "%s", + ( + f"{name} is converted to a symbolic integer. It is an attribute of a " + "user defined nn module class. If you wish to keep it static, you can " + "mark the nn module class as `torch._dynamo.mark_static`." + ), + ) frame_state_entry.scalar = None self.tx.output.frame_state[name] = frame_state_entry @@ -1690,7 +1785,8 @@ def update_frame_state(value): else: # Apply the updates for sub_state in st.all_states: - update_frame_state(sub_state.input_sizes[name]) + if name in sub_state.input_sizes: + update_frame_state(sub_state.input_sizes[name]) frame_state_entry = self.tx.output.frame_state[name] # TODO: This should be dynamic, as we in general do not @@ -2228,6 +2324,7 @@ def _clone_input(value): set_example_value(proxy.node, example_value) return SDPAParamsVariable(proxy, **options) elif isinstance(example_value, bool) and proxy.node.target in [ + torch._C._are_functorch_transforms_active, torch.backends.cuda.is_flash_attention_available, torch.backends.cuda.can_use_flash_attention, torch.backends.cuda.can_use_efficient_attention, @@ -2347,6 +2444,9 @@ def _automatic_dynamic( # Prep for automatic dynamic def update_frame_state(size, stride): + # Intentionally shadow e from parent scope so it is not accidentally + # called + e = None frame_state_entry = None if name not in tx.output.frame_state: # If there is no entry for this source, add the tensor to frame state with its current static size. @@ -2357,13 +2457,13 @@ def update_frame_state(size, stride): else: frame_state_entry = tx.output.frame_state[name] if frame_state_entry.size is not None: - if e.ndim != len(frame_state_entry.size): + if len(size) != len(frame_state_entry.size): # If there is already an entry, and the dim mismatches, replace the frame state entry with None. # E.g. {"x": [2, 3, 4]} -> {"x": None} log.debug( "automatic dynamic %s dim %s != %s", name, - e.ndim, + len(size), frame_state_entry.size, ) frame_state_entry.size = None @@ -2380,7 +2480,7 @@ def update_frame_state(size, stride): "automatic dynamic %s size(%s) %s != %s", name, i, - e.size(i), + size[i], dim, ) frame_state_entry.size[i] = None @@ -2410,7 +2510,7 @@ def update_frame_state(size, stride): "automatic dynamic %s stride(%s) %s != %s", name, i, - e.stride(i), + stride[i], dim, ) frame_state_entry.stride[i] = None @@ -2430,9 +2530,11 @@ def update_frame_state(size, stride): else: # Apply the updates for sub_state in st.all_states: - update_frame_state( - sub_state.input_sizes[name], sub_state.input_strides[name] - ) + # Not all inputs are necessarily present on all ranks + if name in sub_state.input_sizes and name in sub_state.input_strides: + update_frame_state( + sub_state.input_sizes[name], sub_state.input_strides[name] + ) frame_state_entry = tx.output.frame_state[name] # TODO: index export_constraints ahead of time so we don't have to diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index ca972b282e387b..296eb646187c12 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -49,6 +49,7 @@ ConstDictVariable, DefaultDictVariable, DictView, + FrozensetVariable, is_hashable, SetVariable, ) @@ -104,7 +105,7 @@ class BuiltinVariable(VariableTracker): @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) - return BuiltinVariable(value, source=source) + return cls(value, source=source) @staticmethod @functools.lru_cache(None) @@ -636,9 +637,6 @@ def __str__(self) -> str: return f"{self.__class__.__name__}({name})" - def python_type(self): - return type(self.fn) - def as_python_constant(self): return self.fn @@ -994,15 +992,6 @@ def call_method( ) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) - if self.fn is itertools.chain and name == "from_iterable": - assert len(args) == 1 - assert len(kwargs) == 0 - obj = args[0] - items = [] - for item in obj.unpack_var_sequence(tx): - items.extend(item.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - return super().call_method(tx, name, args, kwargs) def _call_int_float(self, tx: "InstructionTranslator", arg): @@ -1069,9 +1058,8 @@ def call_str(self, tx: "InstructionTranslator", arg): return tx.inline_user_function_return(user_func_variable, [arg], {}) def _call_min_max(self, tx: "InstructionTranslator", *args): - if len(args) == 1 and args[0].has_unpack_var_sequence(tx): - # expand iterable - items = args[0].unpack_var_sequence(tx) + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) return self._call_min_max_seq(tx, items) elif len(args) == 2: return self._call_min_max_binary(tx, args[0], args[1]) @@ -1086,6 +1074,10 @@ def _call_min_max_seq(self, tx: "InstructionTranslator", items): return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): + if a is None or b is None: + # a or b could be none if we reduce and _call_min_max_binary failed + # to return something + return if self.tensor_args(a, b): if not isinstance(a, variables.TensorVariable): a, b = b, a @@ -1234,17 +1226,15 @@ def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs): ), ) + # NOTE must handle IteratorVariable separately! def _call_iter_tuple_list( self, tx: "InstructionTranslator", obj=None, *args, **kwargs ): + assert not isinstance(obj, variables.IteratorVariable) + if self._dynamic_args(*args, **kwargs): return self._dyn_proxy(tx, *args, **kwargs) - if isinstance(obj, variables.IteratorVariable): - # For non-list iterators, we will guard on vars that - # determine the control flow - return obj - cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -1272,9 +1262,22 @@ def _call_iter_tuple_list( mutable_local=MutableLocal(), ) + def _call_tuple_list(self, tx, obj=None, *args, **kwargs): + if isinstance(obj, variables.IteratorVariable): + cls = variables.BaseListVariable.cls_for(self.fn) + return cls( + list(obj.force_unpack_var_sequence(tx)), + mutable_local=MutableLocal(), + ) + else: + return self._call_iter_tuple_list(tx, obj, *args, **kwargs) + def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): - # Handle the case where we are iterating over a tuple, list or iterator - ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + if isinstance(obj, variables.IteratorVariable): + ret = obj + else: + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) if ret is None: # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. @@ -1283,8 +1286,8 @@ def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): return obj.call_method(tx, "__iter__", args, kwargs) return ret - call_tuple = _call_iter_tuple_list - call_list = _call_iter_tuple_list + call_tuple = _call_tuple_list + call_list = _call_tuple_list def call_callable(self, tx: "InstructionTranslator", arg): from .functions import BaseUserFunctionVariable @@ -1342,10 +1345,12 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, ), ): items = dict( - x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) + x.force_unpack_var_sequence(tx) + for x in arg.force_unpack_var_sequence(tx) ) return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) elif isinstance(arg, variables.MutableMappingVariable): @@ -1402,13 +1407,12 @@ def call_custom_dict_fromkeys( return DictVariableType( dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() ) - elif arg.has_unpack_var_sequence(tx) and all( - is_hashable(v) for v in arg.unpack_var_sequence(tx) - ): - keys = arg.unpack_var_sequence(tx) - return DictVariableType( - dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() - ) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return DictVariableType( + dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() + ) unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") def call_set(self, tx: "InstructionTranslator", *args, **kwargs): @@ -1420,8 +1424,8 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): arg = args[0] if isinstance(arg, variables.SetVariable): return arg.clone(mutable_local=MutableLocal()) - elif arg.has_unpack_var_sequence(tx): - items = arg.unpack_var_sequence(tx) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) return SetVariable(items, mutable_local=MutableLocal()) elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( arg.value, KeysView @@ -1437,35 +1441,29 @@ def call_set(self, tx: "InstructionTranslator", *args, **kwargs): else: unimplemented(f"set(): {args} {kwargs}") + def call_frozenset(self, tx: "InstructionTranslator", *args, **kwargs): + assert not kwargs + if not args: + return FrozensetVariable([]) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) + return FrozensetVariable(items) + else: + unimplemented(f"frozenset(): {args} {kwargs}") + def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs - if all(x.has_unpack_var_sequence(tx) for x in args): - unpacked = [arg.unpack_var_sequence(tx) for arg in args] - if kwargs.pop("strict", False) and len(unpacked) > 0: - if not all(len(u) == len(unpacked[0]) for u in unpacked): - raise UserError( - ValueError, - "zip() has one argument of len differing from others", - ) - items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] - return variables.TupleVariable(items) - - def call_enumerate(self, tx: "InstructionTranslator", *args): - if len(args) == 1: - start = 0 - else: - assert len(args) == 2 - assert isinstance(args[1], variables.ConstantVariable) - start = args[1].as_python_constant() - if args[0].has_unpack_var_sequence(tx): - items = [ - variables.TupleVariable( - [variables.ConstantVariable.create(idx), var], - ) - for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) - ] - return variables.TupleVariable(items) + strict = kwargs.pop("strict", False) + args = [ + arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg + for arg in args + ] + return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal()) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) @@ -1566,10 +1564,11 @@ def call_hasattr(self, tx: "InstructionTranslator", obj, attr): return obj.call_hasattr(tx, name) def call_map(self, tx: "InstructionTranslator", fn, *seqs): - if all(seq.has_unpack_var_sequence(tx) for seq in seqs): - unpacked = [seq.unpack_var_sequence(tx) for seq in seqs] - items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)] - return variables.TupleVariable(items) + seqs = [ + seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq + for seq in seqs + ] + return variables.MapVariable(fn, seqs, mutable_local=MutableLocal()) def call_filter(self, tx: "InstructionTranslator", fn, seq): if seq.has_unpack_var_sequence(tx): @@ -1872,11 +1871,12 @@ def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker): return variables.TupleVariable(items) def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): - if ( - obj.has_unpack_var_sequence(tx) - and not isinstance(obj, variables.TensorVariable) - and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) + if obj.has_force_unpack_var_sequence(tx) and not isinstance( + obj, variables.TensorVariable ): + unpacked = obj.force_unpack_var_sequence(tx) + if not all(x.is_python_constant() for x in unpacked): + return function = kwargs.pop("key", None) reverse = kwargs.pop( "reverse", ConstantVariable.create(False) @@ -1884,7 +1884,7 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg assert len(kwargs) == 0 if function: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: function.call_function( tx, [x], {} ).as_python_constant(), @@ -1892,28 +1892,12 @@ def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwarg ) else: items = sorted( - obj.unpack_var_sequence(tx), + unpacked, key=lambda x: x.as_python_constant(), reverse=reverse, ) return variables.ListVariable(items) - def call_chain(self, tx: "InstructionTranslator", *args): - if all(obj.has_unpack_var_sequence(tx) for obj in args): - items = [] - for obj in args: - items.extend(obj.unpack_var_sequence(tx)) - return variables.TupleVariable(items) - - def call_islice(self, tx: "InstructionTranslator", iterable, *args): - if iterable.has_unpack_var_sequence(tx) and all( - x.is_python_constant() for x in args - ): - const_args = [x.as_python_constant() for x in args] - items = iterable.unpack_var_sequence(tx) - items = list(itertools.islice(items, *const_args)) - return variables.TupleVariable(items) - # neg is a constant fold function, so we only get here if constant fold is not valid def call_neg(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 8f62a3585bcecf..de357cf8094f3a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -85,9 +85,6 @@ def as_proxy(self): def __str__(self) -> str: return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -148,6 +145,14 @@ def call_method( return variables.BuiltinVariable(str.format).call_function( tx, [self, *args], kwargs ) + elif name == "join" and istype(self.value, str): + assert len(args) == 1 and len(kwargs) == 0 + arg_unpacked = args[0].force_unpack_var_sequence(tx) + try: + arg_const = [x.as_python_constant() for x in arg_unpacked] + return ConstantVariable.create(self.value.join(arg_const)) + except NotImplementedError: + return super().call_method(tx, name, args, kwargs) if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. @@ -217,14 +222,13 @@ def create(cls, cls_type, value_vt, options): unimplemented("Enum variable is constructed with non constant values") def as_proxy(self): + if isinstance(self.value, int): + return int(self.value) # convert IntEnum to a normal int return self.value def __str__(self) -> str: return f"EnumVariable({type(self.value)})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 301b7f3e819345..e19c4e254c647e 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -125,6 +125,12 @@ def call_function( if isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return True + class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are @@ -183,6 +189,12 @@ def exit(self, tx: "InstructionTranslator", *args): tx.generic_context_manager_depth -= 1 return x + def supports_graph_breaks(self): + return False + + def exit_on_graph_break(self): + return True + class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requries grad""" @@ -637,6 +649,8 @@ def enter(self, tx): def _call_func(self, tx: "InstructionTranslator", values): assert len(values) == 1 + tx.symbolic_torch_function_state.torch_function_subclass_enabled = values[0] + tx.symbolic_torch_function_state.torch_function_mode_enabled = values[0] tx.output.set_torch_function_state(values[0]) @@ -977,6 +991,80 @@ def fn_name(self): return "use_training_state" +class SDPAKernelVariable(ContextWrappingVariable): + """represents torch.nn.attention.sdpa_kernel""" + + @staticmethod + def create(tx: "InstructionTranslator", backends, **kwargs): + if isinstance(backends, torch.nn.attention.SDPBackend): + backends = [backends] + var = SDPAKernelVariable( + target_values=backends, + initial_values=None, + **kwargs, + ) + return var + + def __init__( + self, + target_values: List[torch.nn.attention.SDPBackend], + initial_values=None, + **kwargs, + ) -> None: + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) + + @staticmethod + def _backends_to_nodes(tx, backends): + nodes = [] + for backend in backends: + # convert to/from string in order to bake the backend into FX graph + nodes.append( + tx.output.create_node( + "call_function", + torch.nn.attention._backend_from_string, + (backend.name,), + {}, + ) + ) + return nodes + + def enter(self, tx): + self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends() + self.set_cleanup_hook( + tx, lambda: torch.nn.attention._sdpa_kernel(self.prev_backends) + ) + torch.nn.attention._sdpa_kernel(self.target_values) + arg = self._backends_to_nodes(tx, self.target_values) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg,), + {}, + ) + return variables.ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + self.state.cleanup_assert() + arg = self._backends_to_nodes(tx, self.prev_backends) + tx.output.create_node( + "call_function", + torch.nn.attention._sdpa_kernel, + (arg,), + {}, + ) + return variables.ConstantVariable.create(None) + + def module_name(self): + return "torch.nn.attention" + + # use a private version of sdpa_kernel that accepts variadic arguments + # since dynamo reconstructs the contents of target_values one-by-one + def fn_name(self): + return "_sdpa_kernel_variadic" + + class StreamVariable(VariableTracker): def __init__(self, proxy, value, device, **kwargs) -> None: if proxy is not None and "example_value" in proxy.node.meta: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 2d4c015bf3ac1b..e8323ec6f70d76 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -304,33 +304,37 @@ def call_method( tx.output.side_effects.mutation(self) self.items.clear() return ConstantVariable.create(None) - elif ( - name == "update" - and len(args) == 1 - and isinstance( + elif name == "update" and self.mutable_local: + is_args_supported = len(args) == 1 and isinstance( args[0], ( ConstDictVariable, ListVariable, TupleVariable, ListIteratorVariable, + variables.IteratorVariable, UserDefinedObjectVariable, ), ) - and self.mutable_local - ): - tx.output.side_effects.mutation(self) - if isinstance(args[0], ConstDictVariable): - dict_vt = args[0] + + is_kwargs_supported = len(kwargs) > 0 and len(args) == 0 + + if is_args_supported or is_kwargs_supported: + tx.output.side_effects.mutation(self) + if len(args) == 1: + if isinstance(args[0], ConstDictVariable): + dict_vt = args[0] + else: + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) + self.items.update(dict_vt.items) + # Wrap strings + kwargs = { + Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() + } + self.items.update(kwargs) + return ConstantVariable.create(None) else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) - self.items.update(dict_vt.items) - # Wrap strings - kwargs = { - Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() - } - self.items.update(kwargs) - return ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) elif name in ("get", "__getattr__") and args[0] in self: return self.getitem_const(tx, args[0]) elif name == "__contains__" and len(args) == 1: @@ -355,6 +359,15 @@ def call_method( def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] + def call_hasattr(self, tx, name): + # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. + # OrderedDict though requires side effects tracking because it supports arbitrary setattr. + if self.user_cls is dict: + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + unimplemented(f"hasattr on {self.user_cls} is not supported") + class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: @@ -531,6 +544,53 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): raise RuntimeError("Illegal to getitem on a set") +class FrozensetVariable(SetVariable): + def __init__( + self, + items: List[VariableTracker], + **kwargs, + ) -> None: + super().__init__(items, **kwargs) + + def debug_repr(self): + if not self.items: + return "frozenset()" + else: + return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + + @property + def set_items(self): + return self.items.keys() + + def python_type(self): + return frozenset + + def as_python_constant(self): + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen): + codegen.foreach([x.vt for x in self.set_items]) + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + codegen.extend_output(create_call_function(0, False)) + + def call_method( + self, + tx, + name, + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> "VariableTracker": + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + return super().call_method(tx, name, args, kwargs) + + class DictView(VariableTracker): """ Models _PyDictViewObject @@ -662,16 +722,6 @@ def _call_hasattr_customobj( ) -class DataClassVariable(ConstDictVariable): - """ - This class doesn't appear to be used anywhere. - It used to be used to deal with transformers.file_utils.ModelOutput - from huggingface. - - Keeping since we wish to support dataclasses in general in the future - """ - - class CustomizedDictVariable(ConstDictVariable): @staticmethod def is_matching_cls_hf(cls): @@ -905,9 +955,15 @@ def __init__(self, obj, **kwargs) -> None: assert self.is_matching_cls(type(obj)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - from . import ConstantVariable + from .builder import VariableBuilder + + try: + attr_value = getattr(self.obj, name) + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(attr_value) - return ConstantVariable.create(getattr(self.obj, name)) + except AttributeError: + unimplemented(f"getattr({self.value}, {name})") def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": return variables.ConstantVariable.create(hasattr(self.obj, name)) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 1c79ce094822b0..c14b8794cba5f7 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -50,7 +50,7 @@ def is_available(): def is_from_local(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor return inspect.isfunction(value) and value is DTensor.from_local @@ -108,7 +108,7 @@ def is_placement_type(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement return type(value) is type and issubclass(value, Placement) @@ -143,7 +143,7 @@ def is_placement(value): if not DistributedVariable.is_available(): return False - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement return isinstance(value, Placement) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 694d0247a00158..6c1e4a13c94595 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -5,7 +5,7 @@ import inspect import itertools import types -from typing import Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar, Union import torch @@ -16,7 +16,9 @@ from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource from ..utils import ( check_constant_args, + check_unspec_or_constant_args, identity, + is_function, is_wrapper_or_member_descriptor, istype, make_cell, @@ -36,6 +38,9 @@ from torch._guards import Source +_F = TypeVar("_F", bound=Callable) + + def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): # Source propagation is best effort since not every object we encounter has a source to begin with. if isinstance(val, VariableTracker): @@ -134,10 +139,7 @@ class UserFunctionVariable(BaseUserFunctionVariable): @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) def __init__(self, fn, is_constant=False, **kwargs) -> None: super().__init__(**kwargs) @@ -150,6 +152,8 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: assert isinstance( fn, (types.FunctionType, torch.jit.ScriptFunction) ), f"expected FunctionType found {typestr(fn)} {fn}" + # TODO(anijain2305) - Replace directly calling UserFunctionVariable with + # VariableBuilder, which handles the wrapping of _torchdynamo_inline. # unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) self.fn: types.FunctionType = fn @@ -318,7 +322,20 @@ def call_function( return invoke_and_store_as_constant( tx, self.fn, self.get_name(), args, kwargs ) - + if ( + tx.output.current_tracer.under_activation_checkpoint + and not tx.output.current_tracer.allow_side_effects_under_checkpoint + ): + try: + from torch.distributed._composable.fsdp._fsdp_state import FSDPState + except Exception: + FSDPState = None + if FSDPState is not None and self.fn in [ + FSDPState._pre_forward, + FSDPState._post_forward, + ]: + with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): + return super().call_function(tx, args, kwargs) return super().call_function(tx, args, kwargs) @@ -624,9 +641,6 @@ def __init__(self, value, reason=None, **kwargs) -> None: self.value = value self.reason = reason - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -637,10 +651,7 @@ def create_with_source(cls, value, source): # attribute lookup. They are unlikely to be changed, so we can skip # guarding them. install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) @staticmethod @functools.lru_cache(None) @@ -932,24 +943,53 @@ def guard_as_python_constant(self): class PolyfilledFunctionVariable(VariableTracker): _nonvar_fields = { "fn", - *BaseUserFunctionVariable._nonvar_fields, + "wrapped_fn", + "traceable_fn", + *VariableTracker._nonvar_fields, } @classmethod @functools.lru_cache(None) - def _get_polyfill_handlers(cls): + def _get_polyfill_handlers(cls) -> Dict[Callable[..., Any], types.FunctionType]: return {} @classmethod def create_with_source(cls, value, source): - return cls( - value, - source=source, - ) + install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) + + return cls(value, source=source) - def __init__(self, fn: VariableTracker, **kwargs) -> None: + def __init__(self, fn: _F, **kwargs) -> None: super().__init__(**kwargs) - self.fn = fn + self.fn: _F = fn + + handler = self._get_polyfill_handlers().get(fn, fn) + assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" + for candidate_attr in ( + "__torch_dynamo_polyfill__", # registered polyfill + "__python_implementation__", # self handler from third-party libraries + ): + candidate = getattr(handler, candidate_attr, None) + if candidate: + assert callable(candidate) + traceable_fn = candidate + break + else: + raise RuntimeError( + f"Polyfill handler {handler} does not have a traceable function" + ) + + self.wrapped_fn: _F = handler + self.traceable_fn: _F = traceable_fn + + @property + def polyfill_fn(self) -> _F: + return self.traceable_fn + + def can_constant_fold_through(self): + return getattr( + self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False + ) def get_function(self): return self.as_python_constant() @@ -962,25 +1002,38 @@ def call_function( ) -> "VariableTracker": from torch._dynamo.variables.builder import SourcelessBuilder - handler = self._get_polyfill_handlers().get(self.fn) - if handler: - assert callable(handler) - return SourcelessBuilder.create(tx, handler).call_function(tx, args, kwargs) - - for candidate in ("__torch_dynamo_polyfill__", "__python_implementation__"): - handler = getattr(self.fn, candidate, None) - if handler: - assert callable(handler) - if self.source: - source = AttrSource(self.source, candidate) - return UserFunctionVariable.create_with_source( - handler, - source=source, - ).call_function(tx, args, kwargs) - return SourcelessBuilder.create( - tx, - handler, - ).call_function(tx, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( + args, kwargs + ): + result = ( + self.fn( # use the original function which is faster than the polyfill + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + ) + return SourcelessBuilder.create(tx, result) + + traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn) + return traceable_function_variable.call_function(tx, args, kwargs) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "__call__": + return self.call_function(tx, args, kwargs) + + method = getattr(self.fn, name, None) + assert method is not None, f"Member {name} not found in {self.fn}" + assert is_function(method), f"Member {name} is not callable in {self.fn}" + options = {} + if self.source: + options["source"] = AttrSource(self.source, name) + polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) + return polyfilled_method_variable.call_function(tx, args, kwargs) def as_python_constant(self): return self.fn diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7aeb279dfb9704..c8d82b15134461 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -24,7 +24,12 @@ from torch.utils import _pytree as pytree from .. import variables -from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported +from ..exc import ( + IncorrectUsage, + UncapturedHigherOrderOpError, + unimplemented, + Unsupported, +) from ..source import AttrSource from ..utils import proxy_args_kwargs from .dicts import ConstDictVariable @@ -66,6 +71,16 @@ def dynamo_enable_grad(tx: "InstructionTranslator", enable=True): GradModeVariable.create(tx, org_value, initialized=True) +@contextlib.contextmanager +def dynamo_under_activation_checkpoint(tx: "InstructionTranslator"): + orig_val = tx.output.current_tracer.under_activation_checkpoint + try: + tx.output.current_tracer.under_activation_checkpoint = True + yield + finally: + tx.output.current_tracer.under_activation_checkpoint = orig_val + + def only_consist_of(var, types, allow_none=False): if isinstance(var, types): return True @@ -383,6 +398,7 @@ def speculate_subgraph( set_subgraph_inputs="automatic", restore_side_effects=True, should_flatten_outputs=False, + under_activation_checkpoint=False, # Pass in an originating tracer - this is needed for preserving context # across fwd-bwd for autograd.Function tracer=None, @@ -434,6 +450,11 @@ def speculate_subgraph( if enable_grad is not None else contextlib.nullcontext() ) + checkpoint_ctx = ( + dynamo_under_activation_checkpoint(tx) + if under_activation_checkpoint + else contextlib.nullcontext() + ) # For handling side effects, we can make an argument that we don't # have to do anything here. The side effects infra does a good job @@ -453,7 +474,7 @@ def speculate_subgraph( if restore_side_effects: prev_side_effects = tx.output.side_effects.clone() - with autograd_ctx: + with autograd_ctx, checkpoint_ctx: output = f.call_function(tx, args, sub_kwargs) if restore_side_effects: @@ -578,6 +599,8 @@ def make(value, source=None, **kwargs): return OutDtypeHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap": return WrapHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "hints_wrapper": + return HintsWrapperHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "flex_attention": return FlexAttentionHigherOrderVariable(value, source, **kwargs) elif value.__name__ in ( @@ -595,10 +618,14 @@ def make(value, source=None, **kwargs): return RunWithRNGStateHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "associative_scan": return AssociativeScanHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "scan": + return ScanHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "call_torchbind": return CallTorchbindHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap_with_set_grad_enabled": return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "auto_functionalized": + return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) else: unimplemented(f"HigherOrderOperator {value.__name__}") @@ -1013,23 +1040,39 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) - def arg_extractor(combine_fn, input, dim): - return combine_fn, input, dim + def arg_extractor(combine_fn, xs, dim): + return combine_fn, xs, dim - combine_fn, input, dim = arg_extractor(*args, **kwargs) + combine_fn, xs, dim = arg_extractor(*args, **kwargs) - if input.python_type() != list: + if xs.python_type() != list: unimplemented( - f"Expected input to be a list of tensors but got {input.python_type()}", + f"Expected xs to be a list of tensors but got {xs.python_type()}", ) - assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable) + assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable) # Trace the subgraph # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. - null_shape = SourcelessBuilder.create(tx, ()) sub_args = [ - leaf.call_method(tx, "new_empty", args=(null_shape,), kwargs={}) - for leaf in itertools.chain(input.items, input.items) + leaf.call_method( + tx, + "new_empty", + args=( + SourcelessBuilder.create( + tx, + leaf.size + if leaf.size is not None + else BuiltinVariable(getattr) + .call_function(tx, [leaf, ConstantVariable.create("shape")], {}) + .items, + ), + ), + kwargs={ + "dtype": SourcelessBuilder.create(tx, leaf.dtype), + "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad), + }, + ) + for leaf in itertools.chain(xs.items, xs.items) ] ( (combine_result, combine_treespec), @@ -1040,7 +1083,7 @@ def arg_extractor(combine_fn, input, dim): combine_fn, sub_args, sub_kwargs={}, - description="scan_combine", + description="associative_scan_combine_fn", source_target=self.value, set_subgraph_inputs="flatten_manual", ) @@ -1055,9 +1098,9 @@ def arg_extractor(combine_fn, input, dim): f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}", ) - input_proxy = input.as_proxy() + xs_proxy = xs.as_proxy() combine_result_proxy = combine_result.as_proxy() - for result, inp_proxy in zip(combine_result_proxy, input_proxy): + for result, inp_proxy in zip(combine_result_proxy, xs_proxy): inp_meta = inp_proxy.node.meta["example_value"] combine_result_meta = result.node.meta["example_value"] if combine_result_meta.device != inp_meta.device: @@ -1071,29 +1114,213 @@ def arg_extractor(combine_fn, input, dim): + f"got {combine_result_meta.dtype}" ) - if combine_result_meta.shape != (): + combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) + combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm) + + p_args = ( + make_attr(tx, combine_fn_name), + xs_proxy, + dim.as_proxy(), + ) + + with tx.fake_mode: + out_meta = tuple( + inp_proxy.node.meta["example_value"].clone() for inp_proxy in xs_proxy + ) + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", torch.ops.higher_order.associative_scan, p_args, {} + ), + example_value=out_meta, + ) + + +class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="scan must be captured completely with torch.compile." + ) + def call_function( + self, + tx: "InstructionTranslator", + args: List[VariableTracker], + kwargs: Dict[str, VariableTracker], + ) -> VariableTracker: + from torch._higher_order_ops.scan import make_expanded_output_shape + + from .builder import SourcelessBuilder, wrap_fx_proxy + + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + def arg_extractor(combine_fn, init, xs, dim, reverse): + return combine_fn, init, xs, dim, reverse + + combine_fn, init, xs, dim, reverse = arg_extractor(*args, **kwargs) + + if xs.python_type() != list: + unimplemented( + f"Expected xs to be a list of tensors but got {xs.python_type()}", + ) + assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable) + if init.python_type() != list: + unimplemented( + f"Expected init to be a list of tensors but got {init.python_type()}", + ) + assert isinstance(init, torch._dynamo.variables.lists.BaseListVariable) + + dim_fake = ( + dim.as_proxy() + if type(dim.as_proxy()) == int + else get_fake_value(dim.as_proxy().node, tx) + ) + scan_length = get_fake_value(xs.items[0].as_proxy().node, tx).size()[dim_fake] + if scan_length == 0: + unimplemented( + "scan() operator doesn't support zero-sized tensors during tracing." + ) + + init_len = len(init.items) + if init_len == 0: + unimplemented("scan() operator requires init leaves.") + + # Trace the subgraph + # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. + # TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc. + sub_args_init = [ + ini.call_method( + tx, + "new_empty", + args=( + SourcelessBuilder.create( + tx, + ini.size + if ini.size is not None + else tuple( + BuiltinVariable(getattr) + .call_function( + tx, [ini, ConstantVariable.create("shape")], {} + ) + .items + ), + ), + ), + kwargs={ + "dtype": SourcelessBuilder.create(tx, ini.dtype), + "device": SourcelessBuilder.create(tx, ini.device), + "requires_grad": SourcelessBuilder.create(tx, ini.requires_grad), + }, + ) + for ini in init.items + ] + sub_args_inp_shapes = make_expanded_output_shape( + dim_fake, + 1, + [ + tuple( + BuiltinVariable(getattr) + .call_function(tx, [inp, ConstantVariable.create("shape")], {}) + .items + ) + for inp in xs.items + ], + True, + ) + sub_args_inp = [ + inp.call_method( + tx, + "new_empty", + args=(SourcelessBuilder.create(tx, inp_sh),), + kwargs={ + "dtype": SourcelessBuilder.create(tx, inp.dtype), + "device": SourcelessBuilder.create(tx, inp.device), + "requires_grad": SourcelessBuilder.create(tx, inp.requires_grad), + }, + ) + for inp, inp_sh in zip(xs.items, sub_args_inp_shapes) + ] + sub_args = sub_args_init + sub_args_inp + ( + (combine_result, combine_treespec), + combine_graph, + combine_lifted_freevars, + ) = speculate_subgraph( + tx, + combine_fn, + sub_args, + sub_kwargs={}, + description="scan_combine_fn", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + if combine_lifted_freevars: + unimplemented( + f"Combine fn had unexpected freevars: {combine_lifted_freevars}" + ) + + if any(cr.python_type() != list for cr in combine_result.items): + unimplemented( + f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}", + ) + + xs_proxy = xs.as_proxy() + init_proxy = init.as_proxy() + combine_carry_proxy = combine_result.items[0].as_proxy() + + # Checks for carry and init + for ini_proxy, carry in zip(init_proxy, combine_carry_proxy): + ini_meta = ini_proxy.node.meta["example_value"] + carry_meta = carry.node.meta["example_value"] + if ( + carry_meta.device != ini_meta.device + or carry_meta.dtype != ini_meta.dtype + or carry_meta.shape != ini_meta.shape + ): unimplemented( - f"Expected combine_fn to return a tensor with shape () but got {combine_result_meta.shape}" + f"Expected metadata of the combine_fn result {carry_meta} to be the same as " + + f"the metadata of init with {ini_meta}" ) combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) - combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm) + combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm) p_args = ( make_attr(tx, combine_fn_name), - input_proxy, + init_proxy, + xs_proxy, dim.as_proxy(), + reverse.as_proxy(), ) with tx.fake_mode: - out_meta = tuple( - inp_proxy.node.meta["example_value"].clone() - for inp_proxy in input_proxy + # For the fake mode, we need to duplicate the init tensor along the dim + # to have the same size as the xs arguments + # We also do a clone with contiguous_format. This is to be consistent with + # eager semantic of map, which stacks the outputs. The result is contiguous + # as a result of the stack operation. + fake_out_shapes = make_expanded_output_shape( + dim_fake, + scan_length, + [ + get_fake_value(o.as_proxy().node, tx).size() + for o in combine_result.items[1].items + ], ) + out_meta = ( + [init_p.node.meta["example_value"].clone() for init_p in init_proxy], + list( # noqa: C400 + t.as_proxy() + .node.meta["example_value"] + .expand(*sh) + .clone(memory_format=torch.contiguous_format) + for t, sh in zip(combine_result.items[1].items, fake_out_shapes) + ), + ) + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( - "call_function", torch.ops.higher_order.associative_scan, p_args, {} + "call_function", torch.ops.higher_order.scan, p_args, {} ), example_value=out_meta, ) @@ -1293,7 +1520,12 @@ def call_function( class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): def create_wrapped_node( - self, tx: "InstructionTranslator", args, kwargs, description + self, + tx: "InstructionTranslator", + args, + kwargs, + description, + under_activation_checkpoint=False, ): # See NOTE [HigherOrderOperator tracing design] for more details @@ -1309,6 +1541,7 @@ def create_wrapped_node( description, source_target=self.value, should_flatten_outputs=True, + under_activation_checkpoint=under_activation_checkpoint, ) body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) @@ -1431,6 +1664,80 @@ def call_function( ) +class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): + @raise_hard_error_if_graph_break( + reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." + ) + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + _check_supported_callable_arg(tx, args[0], "body_fn") + + # inputs + if len(args) != 3: + unimplemented( + f"Expected 3 arguments but got {len(args)}.\n" + f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n" + f"kwargs required to be provided explicitly." + ) + + if not isinstance(args[1], (ListVariable, TupleVariable)): + unimplemented( + f"Expected a tuple but got {args[1].python_type()}", + ) + operands = args[1].unpack_var_sequence(tx) + + if not isinstance(args[2], ConstDictVariable): + unimplemented( + f"Expected a dict but got {args[2].python_type()}", + ) + + if "hints" not in kwargs: + raise IncorrectUsage("hints_wrapper - key hints not provided") + + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + args[0], # function + operands, + args[2].as_python_constant(), + "hints_wrapper", + source_target=self.value, + should_flatten_outputs=True, + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = add_subgraph( + tx, + "hints_wrapper_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, + # all the arguments are lifted. + lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + p_args = (body_node, lifted_args, {}) + + p_kwargs = {} + # add hints into p_kwargs + p_kwargs["hints"] = kwargs["hints"].as_python_constant() + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, p_args, p_kwargs, flat_example_value, treespec + ) + + class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable): def call_function( self, @@ -1571,7 +1878,11 @@ def call_function( treespec, checkpointed_gmod, ) = self.create_wrapped_node( - tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint" + tx, + args, + gmod_kwargs, + "torch.utils.checkpoint.checkpoint", + under_activation_checkpoint=True, ) if context_fn is not None: checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn @@ -1646,6 +1957,26 @@ def call_function( ) +class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): + def call_function( + self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable): """ Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace @@ -1904,6 +2235,7 @@ def bwd(ctx, grad, x): fwd_args, kwargs, "autograd.Function", + enable_grad=False, set_subgraph_inputs="semi_automatic", restore_side_effects=False, tracer=fwd_tracer, @@ -1975,6 +2307,20 @@ def is_strict_for(v: VariableTracker): # graph.output will append the output at the very end # This might be a behavior difference + # If users call ctx.mark_non_differentiable, we should capture these output tensors who + # are marked as non-differentiable and pass them to ApplyTemplate + # at torch._functorch.autograd_function.AutogradFunctionApply for reconstruction. + non_differentiable_idx = [] + if ctx.non_differentiable is not None: + non_differentiable_set = set(ctx.non_differentiable) + assert isinstance(fwd_out, variables.BaseListVariable) + for i, x in enumerate(fwd_out.items): + if ( + isinstance(x, variables.TensorVariable) + and x.as_proxy() in non_differentiable_set + ): + non_differentiable_idx.append(i) + # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd) for node in fwd_graph.find_nodes(op="output"): fwd_graph.erase_node(node) @@ -2094,7 +2440,10 @@ def is_strict_for(v: VariableTracker): "call_function", autograd_function_apply, args=p_args, - kwargs={"args_tensor_mask": args_tensor_mask}, + kwargs={ + "args_tensor_mask": args_tensor_mask, + "non_differentiable_idx": non_differentiable_idx, + }, ), example_value=example_value, ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 6687611cf0aabd..1f8dac8811f53a 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -2,14 +2,17 @@ import itertools import operator -from typing import Dict, List, Optional, TYPE_CHECKING +import sys +from typing import Dict, List, Optional, TYPE_CHECKING, Union from .. import polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction from ..exc import ( handle_observed_exception, ObservedUserStopIteration, raise_observed_exception, unimplemented, + UserError, ) from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -30,9 +33,6 @@ def __init__(self, value, **kwargs) -> None: def __repr__(self) -> str: return f"ItertoolsVariable({self.value})" - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -52,14 +52,6 @@ def call_function( for item in itertools.product(*seqs): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) - elif ( - self.value is itertools.chain - and not kwargs - and all(arg.has_unpack_var_sequence(tx) for arg in args) - ): - seqs = [arg.unpack_var_sequence(tx) for arg in args] - items = list(itertools.chain.from_iterable(seqs)) - return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) elif self.value is itertools.accumulate: from .builtin import BuiltinVariable @@ -174,12 +166,6 @@ def retrieve_const_key(key): from_exc=e, ) return variables.ListIteratorVariable(result, mutable_local=MutableLocal()) - elif self.value is itertools.islice: - from .builder import SourcelessBuilder - - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.islice), args, kwargs - ) elif self.value is itertools.repeat: if len(args) < 2: return variables.RepeatIteratorVariable( @@ -214,6 +200,25 @@ def __init__(self, **kwargs) -> None: def next_variable(self, tx): unimplemented("abstract method, must implement") + # NOTE: only call when unpacking this iterator safely done eagerly! + # Normally, iterators are accessed lazily. + # Example of safe eager unpacking: list(map(f, seq)) + # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + result = [] + while True: + try: + result.append(self.next_variable(tx)) + except ObservedUserStopIteration: + handle_observed_exception(tx) + break + return result + + # don't call force_unpack_var_sequence since it can mutate + # IteratorVariable state! + def has_force_unpack_var_sequence(self, tx) -> bool: + return True + class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs) -> None: @@ -224,6 +229,18 @@ def __init__(self, item: VariableTracker, **kwargs) -> None: def next_variable(self, tx): return self.item + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("repeat"), + ] + ) + ) + codegen(self.item) + codegen.extend_output(create_call_function(1, False)) + class CountIteratorVariable(IteratorVariable): def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: @@ -237,10 +254,23 @@ def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: def next_variable(self, tx): assert self.mutable_local + old_item = self.item tx.output.side_effects.mutation(self) - next_item = self.item.call_method(tx, "__add__", [self.step], {}) - self.item = next_item - return self.item + self.item = self.item.call_method(tx, "__add__", [self.step], {}) + return old_item + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_python_module(itertools), + codegen.create_load_attr("count"), + ] + ) + ) + codegen(self.item) + codegen(self.step) + codegen.extend_output(create_call_function(2, False)) class CycleIteratorVariable(IteratorVariable): @@ -286,3 +316,160 @@ def next_variable(self, tx): return self.item else: raise_observed_exception(StopIteration, tx, self) + + +class ZipVariable(IteratorVariable): + """ + Represents zip(*iterables) + """ + + _nonvar_fields = { + "index", + "strict", + *IteratorVariable._nonvar_fields, + } + + def __init__( + self, + iterables: List[Union[List[VariableTracker], VariableTracker]], + strict: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + assert isinstance(iterables, list) + # can be list[Variable] or VariableTracker (with next_variable implemented) + self.iterables = iterables + self.index = 0 + self.strict = strict + + def python_type(self): + return zip + + def has_unpack_var_sequence(self, tx) -> bool: + return all( + isinstance(it, list) or it.has_unpack_var_sequence(tx) + for it in self.iterables + ) + + def unpack_var_sequence(self, tx) -> List["VariableTracker"]: + assert self.has_unpack_var_sequence(tx) + iterables = [] + for it in self.iterables: + if isinstance(it, list): + iterables.append(it[self.index :]) + else: + iterables.append(it.unpack_var_sequence(tx)) + kwargs = {"strict": self.strict} if self.strict else {} + zipped = zip(*iterables, **kwargs) + return [variables.TupleVariable(list(var)) for var in zipped] + + def next_variable(self, tx): + assert self.mutable_local + old_index = self.index + args = [] + + def get_item(it): + if isinstance(it, list): + if old_index >= len(it): + raise_observed_exception(StopIteration, tx, self) + return it[old_index] + else: + return it.next_variable(tx) + + try: + for idx, it in enumerate(self.iterables): + args.append(get_item(it)) + except ObservedUserStopIteration: + if self.strict: + if idx == 0: + # all other iterables should be exhausted + for it in self.iterables: + try: + get_item(it) + except ObservedUserStopIteration: + handle_observed_exception(tx) + continue + # no ObservedUserStopIteration - fall through to UserError + break + else: + # all iterables exhausted, raise original error + raise + handle_observed_exception(tx) + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) from None + raise + + tx.output.side_effects.mutation(self) + self.index += 1 + return variables.TupleVariable(args) + + def reconstruct_items(self, codegen): + for it in self.iterables: + if isinstance(it, list): + remaining_items = it[self.index :] + codegen.foreach(remaining_items) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(remaining_items)) + ) + else: + codegen(it) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True + ) + self.reconstruct_items(codegen) + codegen.append_output( + create_instruction("BUILD_TUPLE", arg=len(self.iterables)) + ) + if sys.version_info >= (3, 10): + codegen.extend_output( + [ + codegen.create_load_const("strict"), + codegen.create_load_const(self.strict), + create_instruction("BUILD_MAP", arg=1), + create_instruction("CALL_FUNCTION_EX", arg=1), + ] + ) + else: + codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) + + +class MapVariable(ZipVariable): + """ + Represents map(fn, *iterables) + """ + + def __init__( + self, + fn: VariableTracker, + iterables: List[Union[List[VariableTracker], VariableTracker]], + **kwargs, + ) -> None: + super().__init__(iterables, **kwargs) + self.fn = fn + + def python_type(self): + return map + + def has_unpack_var_sequence(self, tx) -> bool: + return False + + def next_variable(self, tx): + args = super().next_variable(tx) + return self.fn.call_function(tx, args.items, {}) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True + ) + codegen(self.fn) + self.reconstruct_items(codegen) + codegen.extend_output( + [ + create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1), + create_instruction("CALL_FUNCTION_EX", arg=0), + ] + ) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 83be6fa767392a..30916e0b699698 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -29,6 +29,7 @@ from .base import MutableLocal, VariableTracker from .constant import ConstantVariable from .functions import UserFunctionVariable, UserMethodVariable +from .iter import IteratorVariable if TYPE_CHECKING: @@ -334,11 +335,11 @@ def call_method( name == "extend" and self.mutable_local and args - and args[0].has_unpack_var_sequence(tx) + and args[0].has_force_unpack_var_sequence(tx) ): assert not kwargs (arg,) = args - seq = arg.unpack_var_sequence(tx) + seq = arg.force_unpack_var_sequence(tx) tx.output.side_effects.mutation(self) self.items.extend(seq) return ConstantVariable.create(None) @@ -422,17 +423,29 @@ def call_method( key, value = args tx.output.side_effects.mutation(self) if isinstance(key, SliceVariable): - if not value.has_unpack_var_sequence(tx): + if not value.has_force_unpack_var_sequence(tx): unimplemented( f"Missing dynamo support for expanding {value} into a list for slice assignment." ) - self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) + self.items[key.as_python_constant()] = value.force_unpack_var_sequence( + tx + ) else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is list: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if self.python_type() is not list: return super().call_hasattr(tx, name) @@ -454,7 +467,12 @@ def reconstruct(self, codegen): ) ) codegen.foreach(self.items) - codegen.extend_output(create_call_function(len(self.items), False)) + codegen.extend_output( + [ + create_instruction("BUILD_LIST", arg=len(self.items)), + *create_call_function(1, False), + ] + ) def call_method( self, @@ -477,11 +495,15 @@ def call_method( tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value return ConstantVariable.create(None) - elif name == "extendleft" and self.mutable_local: + elif ( + name == "extendleft" + and self.mutable_local + and args[0].has_force_unpack_var_sequence(tx) + ): assert not kwargs (arg,) = args - prefix = arg.unpack_var_sequence(tx) + prefix = arg.force_unpack_var_sequence(tx) prefix.reverse() tx.output.side_effects.mutation(self) self.items = prefix + list(self.items) @@ -525,6 +547,16 @@ def call_method( ) -> "VariableTracker": return super().call_method(tx, name, args, kwargs) + def var_getattr(self, tx, name): + if name == "__class__": + source = AttrSource(self.source, name) if self.source else None + class_type = self.python_type() + if class_type is tuple: + return variables.BuiltinVariable(class_type, source=source) + else: + return variables.UserDefinedClassVariable(class_type, source=source) + return super().var_getattr(tx, name) + def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if self.python_type() is not tuple: return super().call_hasattr(tx, name) @@ -730,7 +762,7 @@ def check_and_create_method(): if name not in fields: method = check_and_create_method() if not method: - super().var_getattr(tx, name) + return super().var_getattr(tx, name) return method return self.items[fields.index(name)] @@ -782,10 +814,10 @@ def var_getattr(self, tx: "InstructionTranslator", name): return self.items[fields.index(name)] -class ListIteratorVariable(VariableTracker): +class ListIteratorVariable(IteratorVariable): _nonvar_fields = { "index", - *VariableTracker._nonvar_fields, + *IteratorVariable._nonvar_fields, } def __init__(self, items, index: int = 0, **kwargs) -> None: @@ -836,6 +868,9 @@ def as_python_constant(self): def unpack_var_sequence(self, tx): return list(self.items[self.index :]) + def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: + return self.unpack_var_sequence(tx) + def reconstruct(self, codegen): remaining_items = self.items[self.index :] codegen.foreach(remaining_items) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 0e3ae4a77a3632..663ff5b20a6d0b 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -35,7 +35,13 @@ set_example_value, ) from .base import VariableTracker -from .functions import NestedUserFunctionVariable, UserFunctionVariable, wrap_bound_arg +from .functions import ( + NestedUserFunctionVariable, + UserFunctionVariable, + UserMethodVariable, + wrap_bound_arg, +) +from .nn_module import UnspecializedNNModuleVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -43,6 +49,10 @@ from torch._dynamo.symbolic_convert import InstructionTranslator +class NO_SUCH_SUBOBJ: + pass + + class SuperVariable(VariableTracker): _nonvar_fields = { "specialized", @@ -93,22 +103,34 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): type_to_use_source = self.objvar.source source = None - if self.objvar.source is not None: - # Walk the mro tuple to find out the actual class where the - # attribute resides. - search_mro = type_to_use.__mro__ + resolved_class = None + resolved_attr = None + search_mro = type_to_use.__mro__ + + try: start_index = search_mro.index(search_type) + 1 - for index in range(start_index, len(search_mro)): - if hasattr(search_mro[index], name): + except ValueError: + # Corner case where the typevar is not in the mro of the objvar + # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844 + return getattr(super(search_type, type_to_use), name), None + # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812 + # super has its getattro implementation. The key point is that instead of calling getattr, it checks the + # attribute in the class __dict__ + for index in range(start_index, len(search_mro)): + # Dont call getattr, just check the __dict__ of the class + if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ): + if resolved_getattr is not NO_SUCH_SUBOBJ: # Equivalent of something like type(L['self']).__mro__[1].attr_name - source = AttrSource( - GetItemSource(AttrSource(type_to_use_source, "__mro__"), index), - name, - ) - break + if type_to_use_source: + source = AttrSource( + GetItemSource( + AttrSource(type_to_use_source, "__mro__"), index + ), + name, + ) + return resolved_getattr, source - # TODO(jansel): there is a small chance this could trigger user code, prevent that - return getattr(super(search_type, type_to_use), name), source + unimplemented("Unable to resolve super getattr") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": # Check if getattr is a constant. If not, delay the actual work by @@ -156,12 +178,17 @@ def call_method( return tx.output.side_effects.track_object_new_from_user_defined_class( self.objvar ) - elif name == "__new__" and isinstance(inner_fn, types.FunctionType): - # __new__ is a staticmethod object, but accessing __new__ from the super object, as done in - # _resolved_getattr_and_source, results in a function object. If not specialized here, it will try to add - # the `self` arg and fail bind arg matching later. + elif isinstance(inner_fn, staticmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): return variables.UserFunctionVariable( - inner_fn, source=source + inner_fn.__func__, source=source + ).call_function(tx, args, kwargs) + elif isinstance(inner_fn, classmethod) and isinstance( + inner_fn.__func__, types.FunctionType + ): + return variables.UserMethodVariable( + inner_fn.__func__, self.objvar, source=source ).call_function(tx, args, kwargs) elif isinstance(inner_fn, types.FunctionType): return variables.UserFunctionVariable( @@ -351,6 +378,7 @@ class InspectSignatureVariable(VariableTracker): _nonvar_fields = { "signature", + "parameters", *VariableTracker._nonvar_fields, } @@ -366,18 +394,29 @@ def __init__(self, inspected: VariableTracker, **kwargs) -> None: super().__init__(**kwargs) self.inspected = inspected - if isinstance(self.inspected, UserFunctionVariable): - self.fn = self.inspected.get_function() - else: - self.fn = self.inspected.as_python_constant() + try: + if hasattr(self.inspected, "get_function"): + self.fn = self.inspected.get_function() + elif isinstance(self.inspected, UnspecializedNNModuleVariable): + self.fn = self.inspected.value + else: + self.fn = self.inspected.as_python_constant() + except NotImplementedError: + unimplemented("inspect.signature with non-constant function") + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) + if isinstance(self.inspected, UserMethodVariable): + self.parameters = self.parameters[1:] def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "parameters": return variables.ConstDictVariable( { - variables.ConstantVariable.create(name): InspectParameterVariable() - for name in self.inspected.inspect_parameter_names() + variables.ConstantVariable.create( + param[0] + ): InspectParameterVariable(param[1]) + for param in self.parameters }, user_cls=dict, ) @@ -433,7 +472,24 @@ def reconstruct(self, codegen): class InspectParameterVariable(VariableTracker): - """This is not implemented, if used will graph break.""" + """represents inspect.Parameter(...)""" + + def __init__(self, value, **kwargs) -> None: + super().__init__(**kwargs) + self.value = value + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + + try: + attr_value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(attr_value) + else: + return SourcelessBuilder.create(tx, attr_value) + except AttributeError: + unimplemented(f"getattr({self.value}, {name})") class InspectBoundArgumentsVariable(VariableTracker): @@ -776,6 +832,7 @@ def __init__( proxy=None, saved_tensors=None, needs_input_grad=None, + non_differentiable=None, **kwargs, ) -> None: super().__init__(value=value, value_type=value_type, **kwargs) @@ -783,6 +840,7 @@ def __init__( self.proxy = proxy self.saved_tensors = saved_tensors self.needs_input_grad = needs_input_grad + self.non_differentiable = non_differentiable @staticmethod def create(tx: "InstructionTranslator", args=None, kwargs=None): @@ -825,6 +883,11 @@ def call_method( ) -> "VariableTracker": if name == "__setattr__": return super().call_method(tx, name, args, kwargs) + elif name == "mark_non_differentiable": + assert len(kwargs) == 0 + self.non_differentiable = proxy_args_kwargs(args, {})[0] + return variables.ConstantVariable.create(None) + if name != "save_for_backward": unimplemented(f"autograd.Function context method: {name}") if self.saved_tensors is None: @@ -844,7 +907,7 @@ def call_method( return variables.ConstantVariable.create(None) def var_getattr(self, tx: "InstructionTranslator", name): - if name == "save_for_backward": + if name in ["save_for_backward", "mark_non_differentiable"]: return LambdaVariable( lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) ) @@ -1132,9 +1195,6 @@ def call_method( ) unimplemented("typing") - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -1252,9 +1312,6 @@ def call_method( ) -> "VariableTracker": unimplemented("numpy") - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -1429,9 +1486,6 @@ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index ea941890b5d129..5cb02c077cbc30 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -795,6 +795,20 @@ def method___len__(self): tx = InstructionTranslator.current_tx() return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + def method_addcmul_(self, tensor1, tensor2, *, value=None): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + if value is not None: + from .. import polyfills + from .builder import SourcelessBuilder + + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.addcmul_inplace), + [self, tensor1, tensor2, value], + {}, + ) + def method___setitem__(self, key, value): def has_bool_key(v): if isinstance(v, TensorVariable): @@ -992,12 +1006,15 @@ def _register_hook_trampoline(tensor, bw_state): from .builder import wrap_fx_proxy + self_proxy = self.as_proxy() + self_proxy.node.meta["has_backward_hook"] = True + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", _register_hook_trampoline, - (self.as_proxy(), bw_state_proxy), + (self_proxy, bw_state_proxy), {}, ), ) @@ -1314,9 +1331,6 @@ def call_function( def as_python_constant(self): return self.value - def python_type(self): - return type(self.value) - class UntypedStorageVariable(VariableTracker): _nonvar_fields = { diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1d86c16a6e09ec..c1e0dec0fbc415 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -15,6 +15,7 @@ from torch._guards import TracingContext from torch._logging import warning_once from torch._streambase import _StreamBase +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type from .. import config, polyfills, variables from ..codegen import PyCodegen @@ -26,7 +27,7 @@ from ..device_interface import get_registered_device_interfaces from ..exc import unimplemented from ..guards import GuardBuilder, install_guard -from ..source import SyntheticLocalSource +from ..source import CallFunctionNoArgsSource, SyntheticLocalSource from ..utils import ( check_unspec_or_constant_args, guard_if_dyn, @@ -88,6 +89,8 @@ torch.autograd.graph.disable_saved_tensors_hooks, torch.cpu.amp.autocast_mode.autocast, torch.cuda.amp.autocast_mode.autocast, + torch.nn.attention.sdpa_kernel, + torch.nn.attention._sdpa_kernel_variadic, ] ) @@ -99,6 +102,10 @@ ] ) +constant_fold_functions_need_guards = [ + torch.cuda.current_device, +] + constant_fold_functions = [ torch._assert, torch._utils._get_device_index, @@ -119,7 +126,7 @@ torch.promote_types, torch._C._get_privateuse1_backend_name, torch.autograd._is_checkpoint_valid, -] +] + constant_fold_functions_need_guards if torch.distributed.is_available(): constant_fold_functions.extend( [ @@ -129,6 +136,7 @@ ] ) # Convert to dict for O(1) access times +constant_fold_functions_need_guards = dict.fromkeys(constant_fold_functions_need_guards) constant_fold_functions = dict.fromkeys(constant_fold_functions) @@ -148,16 +156,32 @@ bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) +@functools.lru_cache(None) +def get_overridable_functions(): + from itertools import chain + + from torch.overrides import get_overridable_functions as get_overridable_functions_ + + funcs = set(chain(*get_overridable_functions_().values())) + more = { + torch.ones, + torch.ones_like, + torch.zeros, + torch.zeros_like, + torch.empty, + torch.full, + } + funcs.update(more) + return funcs + + class BaseTorchVariable(VariableTracker): """common base for all torch.* functions, classes, modules and other things""" @classmethod def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH)) - return cls( - value, - source=source, - ) + return cls(value, source=source) def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) @@ -176,9 +200,6 @@ def reconstruct(self, codegen): def as_proxy(self): return self.value - def python_type(self): - return type(self.value) - def as_python_constant(self): return self.value @@ -229,6 +250,7 @@ def call_function( GradModeVariable, InferenceModeVariable, JvpIncrementNestingCtxManagerVariable, + SDPAKernelVariable, SetFwdGradEnabledContextManager, StreamVariable, VmapIncrementNestingCtxManagerVariable, @@ -329,6 +351,14 @@ def call_function( return FSDPParamGroupUseTrainingStateVariable.create( tx, args[0], args[1].as_python_constant() ) + elif self.value is torch.nn.attention.sdpa_kernel: + assert len(args) == 1 or (len(kwargs) == 1 and "backends" in kwargs) + backends = args[0] if len(args) == 1 else kwargs["backends"] + return SDPAKernelVariable.create(tx, backends.as_python_constant()) + elif self.value is torch.nn.attention._sdpa_kernel_variadic: + return SDPAKernelVariable.create( + tx, [arg.as_python_constant() for arg in args] + ) return super().call_function(tx, args, kwargs) @@ -446,6 +476,14 @@ def handle_numel(self, tx: "InstructionTranslator", input): # Workaround dynamic shapes issue return input.call_method(tx, "numel", [], {}) + @register(torch.compile) + def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) == 1: + # torch.compile is a no-op in dynamo + return args[0] + + unimplemented("torch.compile is used as a decorator in the compiled frame") + @register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD) def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input): assert isinstance(input, TensorVariable) @@ -581,6 +619,30 @@ def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): tx, [args[0], result], {} ) + @register(torch._foreach_lerp_) + def handle_inplace_foreach_lerp_scalar( + self, tx: "InstructionTranslator", *args, **kwargs + ): + if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace), + args, + kwargs, + ) + + @register(torch._foreach_pow) + def handle_foreach_pow_scalar( + self, tx: "InstructionTranslator", *args, **kwargs + ): + # In eager it's more performant to call item() from within the C op implementation + # in compile, it's more performant to not graph break. + if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs: + return tx.inline_user_function_return( + SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar), + args, + kwargs, + ) + @register(torch._assert) def handle_assert(self, tx: "InstructionTranslator", condition, message): if (condition.is_python_constant() and condition.as_python_constant()) or ( @@ -602,7 +664,6 @@ def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): ) if DistributedVariable.is_available(): - from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ( _get_group_size_by_name, _get_group_tag, @@ -610,6 +671,7 @@ def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs): _resolve_group_name_by_ranks_and_tag, get_process_group_ranks, ) + from torch.distributed.tensor import DTensor @register( _get_group_size_by_name, @@ -652,10 +714,19 @@ def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function args_as_value = [x.as_python_constant() for x in args[1:]] - kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} - - def fn_with_prim_types(x): - return self.value(x, *args_as_value, **kwargs_as_value) + kwargs_as_value = { + k: v.as_python_constant() + for k, v in kwargs.items() + if k not in ["shape", "stride"] + } + kwargs_to_be_proxied = { + k: kwargs[k] for k in ["shape", "stride"] if k in kwargs + } + + def fn_with_prim_types(x, shape=None, stride=None): + return self.value( + x, *args_as_value, **kwargs_as_value, shape=shape, stride=stride + ) # attach the same function name for better debugging fn_with_prim_types.__name__ = "prim " + self.value.__name__ @@ -665,7 +736,10 @@ def fn_with_prim_types(x): proxy=tx.output.create_proxy( "call_function", fn_with_prim_types, - *proxy_args_kwargs([args[0]], {}), + *proxy_args_kwargs( + [args[0]], + kwargs_to_be_proxied, + ), ), ) @@ -751,10 +825,10 @@ def handle_pop_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - if not tx.symbolic_torch_function_mode_stack: + if not tx.symbolic_torch_function_state.mode_stack: raise unimplemented("Popping from an empty torch function mode stack") TorchFunctionModeStackVariable.register_mutation(tx) - return tx.symbolic_torch_function_mode_stack.pop() + return tx.symbolic_torch_function_state.pop_torch_function_mode() @register(torch._C._push_on_torch_function_stack) def handle_push_torch_function( @@ -762,7 +836,7 @@ def handle_push_torch_function( ): assert len(args) == 1 and not kwargs TorchFunctionModeStackVariable.register_mutation(tx) - tx.symbolic_torch_function_mode_stack.append(args[0]) + tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) @@ -770,7 +844,16 @@ def handle_len_torch_function( self, tx: "InstructionTranslator", *args, **kwargs ): assert not args and not kwargs - return ConstantVariable.create(len(tx.symbolic_torch_function_mode_stack)) + return ConstantVariable.create( + len(tx.symbolic_torch_function_state.mode_stack) + ) + + @register(torch._C._get_function_stack_at) + def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): + assert len(args) == 1 and not kwargs + ind = args[0].as_python_constant() + assert ind >= 0 and ind < len(tx.symbolic_torch_function_state.mode_stack) + return tx.symbolic_torch_function_state.mode_stack[ind] @register(torch.set_default_device) def handle_set_default_device( @@ -789,7 +872,7 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return None + return ConstantVariable.create(None) return handlers @@ -802,9 +885,16 @@ def call_function( from . import ConstantVariable, SymNodeVariable, TensorVariable from .builder import wrap_fx_proxy + if self.torch_function_override_enabled(tx, args, kwargs): + return dispatch_torch_function(tx, self, args, kwargs) + if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): + # constant fold functions need to be guarded. + if self.value in constant_fold_functions_need_guards: + source = CallFunctionNoArgsSource(self.source) + install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold return ConstantVariable.create( self.as_python_constant()( @@ -819,147 +909,144 @@ def call_function( if result: return result - if can_dispatch_torch_function(tx, args, kwargs): - return dispatch_torch_function(tx, self, args, kwargs) - else: - any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) + any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args) - all_ints_or_floats = all( - isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) - for x in args - ) - if ( - getattr(self.value, "__module__", "") == "torch" - and self.value.__name__ in bin_ops - and any_symints_or_symfloats - and all_ints_or_floats - ): - msg = f"""\ + all_ints_or_floats = all( + isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable)) + for x in args + ) + if ( + getattr(self.value, "__module__", "") == "torch" + and self.value.__name__ in bin_ops + and any_symints_or_symfloats + and all_ints_or_floats + ): + msg = f"""\ Calling {str(self.value)} on only torch.SymInt arguments is not yet supported. To support this behavior, we need to allow const-propping tensors that store symint data. For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ - log.warning(msg) - unimplemented(msg) - - # TODO(voz): Replace w/ dynamic shape rewrite table. - # Ideally, we would be able to do this at ctor time, but alas we need a combination - # of value + args to determine this. - fn_ = self.value - if any_symints_or_symfloats: - torch_sym_op = f"_sym_{self.value.__name__}" - if getattr(self.value, "__module__", None) == "math" and hasattr( - torch, torch_sym_op - ): - fn_ = getattr(torch, torch_sym_op) - - fake_out_shape = None - if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): - # Calling fake tensor propagation can mutate the out= tensor in - # tx.output.tracked_fakes. tracked_fakes are used to apply - # symbolic_shape guards. Mutating them destroys the information - # prior to tracing, which is essential for creating right - # guards. So save the shape now, and check later if it has - # changed. If it has, graph break. - fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape - - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - *proxy_args_kwargs(args, kwargs), - ), - ) - - if ( - isinstance(tensor_variable, TensorVariable) - and "requires_grad" in kwargs - and kwargs["requires_grad"].as_python_constant() + log.warning(msg) + unimplemented(msg) + + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. + fn_ = self.value + if any_symints_or_symfloats: + torch_sym_op = f"_sym_{self.value.__name__}" + if getattr(self.value, "__module__", None) == "math" and hasattr( + torch, torch_sym_op ): - unimplemented( - """factory functions that return tensors that require grad are not supported. + fn_ = getattr(torch, torch_sym_op) + + fake_out_shape = None + if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable): + # Calling fake tensor propagation can mutate the out= tensor in + # tx.output.tracked_fakes. tracked_fakes are used to apply + # symbolic_shape guards. Mutating them destroys the information + # prior to tracing, which is essential for creating right + # guards. So save the shape now, and check later if it has + # changed. If it has, graph break. + fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape + + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + *proxy_args_kwargs(args, kwargs), + ), + ) + + if ( + isinstance(tensor_variable, TensorVariable) + and "requires_grad" in kwargs + and kwargs["requires_grad"].as_python_constant() + ): + unimplemented( + """factory functions that return tensors that require grad are not supported. Either create the tensor outside the compiled region, or do not set the tensor to require_grad""" - ) + ) - if "out" in kwargs and not ( - isinstance(kwargs["out"], variables.ConstantVariable) - and kwargs["out"].as_python_constant() is None - ): - # out variants of torch operators like torch.sort and - # torch.sigmoid mutate the tensors in the out field. Track such - # tensors and rewrite the symbolic locals. - if isinstance(tensor_variable, TupleVariable): - assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) - output_tensor_names = [ - tx.find_symbolic_locals_name(x) for x in kwargs["out"].items - ] - for idx, name in enumerate(output_tensor_names): - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable.items[idx] - for out_tensor, result_tensor in zip( - kwargs["out"].items, tensor_variable.items - ): - if ( - out_tensor.source - and out_tensor in tx.output.graphargs - and isinstance(out_tensor, variables.TensorVariable) - and isinstance(result_tensor, variables.TensorVariable) - and out_tensor.size != result_tensor.size - ): - # It's hard to get out variants with resizing on graph inputs work - # properly across dynamo/aot/inductor, just fall back. - unimplemented("out variants with resizing on graph inputs") - elif isinstance(tensor_variable, TensorVariable): - assert isinstance(kwargs["out"], TensorVariable) - assert "example_value" in kwargs["out"].proxy.node.meta - fake_tensor = tensor_variable.proxy.node.meta["example_value"] - fake_out = kwargs["out"].proxy.node.meta["example_value"] + if "out" in kwargs and not ( + isinstance(kwargs["out"], variables.ConstantVariable) + and kwargs["out"].as_python_constant() is None + ): + # out variants of torch operators like torch.sort and + # torch.sigmoid mutate the tensors in the out field. Track such + # tensors and rewrite the symbolic locals. + if isinstance(tensor_variable, TupleVariable): + assert isinstance(kwargs["out"], (TupleVariable, ListVariable)) + output_tensor_names = [ + tx.find_symbolic_locals_name(x) for x in kwargs["out"].items + ] + for idx, name in enumerate(output_tensor_names): + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable.items[idx] + for out_tensor, result_tensor in zip( + kwargs["out"].items, tensor_variable.items + ): if ( - kwargs["out"].source - and kwargs["out"] in tx.output.graphargs - and fake_out_shape != fake_tensor.shape + out_tensor.source + and out_tensor in tx.output.graphargs + and isinstance(out_tensor, variables.TensorVariable) + and isinstance(result_tensor, variables.TensorVariable) + and out_tensor.size != result_tensor.size ): # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. unimplemented("out variants with resizing on graph inputs") + elif isinstance(tensor_variable, TensorVariable): + assert isinstance(kwargs["out"], TensorVariable) + assert "example_value" in kwargs["out"].proxy.node.meta + fake_tensor = tensor_variable.proxy.node.meta["example_value"] + fake_out = kwargs["out"].proxy.node.meta["example_value"] + if ( + kwargs["out"].source + and kwargs["out"] in tx.output.graphargs + and fake_out_shape != fake_tensor.shape + ): + # It's hard to get out variants with resizing on graph inputs work + # properly across dynamo/aot/inductor, just fall back. + unimplemented("out variants with resizing on graph inputs") + if not torch._prims_common.is_contiguous(fake_out): + # It's difficult to handle strides correctly in functionalization + # when calling an out= op with a non-contiguous out argument + unimplemented( + "out= op was called where output tensor was non-contiguous" + ) + name = tx.find_symbolic_locals_name(kwargs["out"]) + if name in tx.symbolic_locals: + tx.symbolic_locals[name] = tensor_variable + elif ( + isinstance(tensor_variable, ConstantVariable) + and tensor_variable.value is None + ): + # Handle out-variant custom ops that return None. + if isinstance(kwargs["out"], TensorVariable): + assert "example_value" in kwargs["out"].proxy.node.meta + fake_out = kwargs["out"].proxy.node.meta["example_value"] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( "out= op was called where output tensor was non-contiguous" ) - name = tx.find_symbolic_locals_name(kwargs["out"]) - if name in tx.symbolic_locals: - tx.symbolic_locals[name] = tensor_variable - elif ( - isinstance(tensor_variable, ConstantVariable) - and tensor_variable.value is None - ): - # Handle out-variant custom ops that return None. - if isinstance(kwargs["out"], TensorVariable): - assert "example_value" in kwargs["out"].proxy.node.meta - fake_out = kwargs["out"].proxy.node.meta["example_value"] + elif isinstance(kwargs["out"], ListVariable): + for idx, x in enumerate(kwargs["out"].items): + assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] + fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument unimplemented( - "out= op was called where output tensor was non-contiguous" + "out= op was called where some of the output tensors were non-contiguous" ) - elif isinstance(kwargs["out"], ListVariable): - for idx, x in enumerate(kwargs["out"].items): - assert "example_value" in x.proxy.node.meta # type: ignore[attr-defined] - fake_out = x.proxy.node.meta["example_value"] # type: ignore[attr-defined] - if not torch._prims_common.is_contiguous(fake_out): - # It's difficult to handle strides correctly in functionalization - # when calling an out= op with a non-contiguous out argument - unimplemented( - "out= op was called where some of the output tensors were non-contiguous" - ) - else: - unimplemented(f"out variant of {type(kwargs['out'])}") + else: + unimplemented(f"out variant of {type(kwargs['out'])}") - return tensor_variable + return tensor_variable def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs): """inline behavior of torch.nn.modules.utils._ntuple""" @@ -1007,6 +1094,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) + if is_traceable_wrapper_subclass_type(data.class_type): + unimplemented("Parameter constructor with tensor subclass NYI") + if not can_convert_to_tracable_parameter(): unimplemented("Workaround for issues with nn_parameter construction") @@ -1084,3 +1174,12 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad source ) return result + + def torch_function_override_enabled(self, tx, args, kwargs): + return ( + self.get_function() in get_overridable_functions() + or isinstance( + self.get_function(), + (torch._ops.OpOverload, torch._ops.OpOverloadPacket), + ) + ) and can_dispatch_torch_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4b3188507fe4b5..ffb3d27d4d703b 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,20 +1,37 @@ # mypy: ignore-errors +import collections +import contextlib +import functools import inspect -from typing import Dict, List, TYPE_CHECKING +from typing import Deque, Dict, List, TYPE_CHECKING +import torch._C import torch.utils._pytree as pytree from torch._guards import Source -from torch.overrides import _get_overloaded_args, get_default_nowrap_functions +from torch.overrides import ( + _get_overloaded_args, + get_default_nowrap_functions, + TorchFunctionMode, +) from torch.utils._device import DeviceContext from ..exc import unimplemented from ..guards import GuardBuilder, install_guard +from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource -from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter +from ..utils import ( + class_has_getattribute, + clear_torch_function_mode_stack, + get_safe_global_name, + has_torch_function, + is_tensor_base_attr_getter, + set_torch_function_mode_stack, +) from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable +from .ctx_manager import GenericContextWrappingVariable +from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable from .user_defined import UserDefinedObjectVariable @@ -52,11 +69,99 @@ if is_tensor_base_attr_getter(fn) ] -# Today set default device is placed in the graph and guarded on separately -# so we should not trace through it. In the future we can trace it once -# mode tracing is implemented and not put in the graph, but this is more -# of a BE project and can be evaluated later -IGNORED_MODES = {DeviceContext} + +@functools.lru_cache(None) +def get_prev_stack_var_name(): + from ..bytecode_transformation import unique_id + + return unique_id("___prev_torch_function_mode_stack") + + +# Used to clear/restore the python torch function mode stack and temporarily restore it as needed +class TorchFunctionModeStackStateManager: + def __init__(self): + self.stack = [] + + def __enter__(self): + self.stack = torch.overrides._get_current_function_mode_stack() + clear_torch_function_mode_stack() + + def __exit__(self, exc_type, exc_value, traceback): + set_torch_function_mode_stack(self.stack) + self.stack = [] + + @contextlib.contextmanager + def temp_restore_stack(self): + prev = torch.overrides._get_current_function_mode_stack() + set_torch_function_mode_stack(self.stack) + try: + yield + finally: + set_torch_function_mode_stack(prev) + + +torch_function_mode_stack_state_mgr = TorchFunctionModeStackStateManager() + + +class SymbolicTorchFunctionState: + def __init__(self, py_stack): + # This is annoyingly complicated because of how the torch function subclass + mode C API was designed + # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass + # These are their definitions: + # 1) torch._C._is_torch_function_enabled indicates that neither of the above knobs have been entered + # (if either are entered, this will be False) + # 2) torch._C._is_torch_function_mode_enabled indicates that either the torch mode stack is empty OR + # torch._C.DisableTorchFunction has been entered + # To disambiguate these and keep myself sane I added a C API to check whether all torch function + # concepts (modes and subclasses) are enabled. + # This only returns true iff we have not entered torch._C.DisableTorchFunction and allows us to separate + # the stack length from the enablement state of torch function modes. + # This is important because now if a mode is pushed while dynamo is tracing, we know whether + # or not torch function modes are enabled and whether we should trace it. + self.torch_function_subclass_enabled = torch._C._is_torch_function_enabled() + + # This differs from the C API of the same name + # this will only be false iff we have entered torch._C.DisableTorchFunction + # and does not take into account the mode stack length, while the C API bundles these + # two concepts + self.torch_function_mode_enabled = ( + not torch._C._is_torch_function_all_disabled() + ) + + self.cur_mode = None + + TorchFunctionModeStackVariable.reset() + + self.mode_stack: Deque[TorchFunctionModeVariable] = collections.deque() + + for i, val in enumerate(py_stack): + self.mode_stack.append( + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + ) + + def in_torch_function_mode(self): + return len(self.mode_stack) > 0 + + def pop_torch_function_mode(self): + return self.mode_stack.pop() + + def push_torch_function_mode(self, mode_var): + self.mode_stack.append(mode_var) + + def call_torch_function_mode(self, tx, fn, types, args, kwargs): + with self._pop_mode_for_inlining() as cur_mode: + return cur_mode.call_torch_function(tx, fn, types, args, kwargs) + + @contextlib.contextmanager + def _pop_mode_for_inlining(self): + old_mode = self.cur_mode + self.cur_mode = self.pop_torch_function_mode() + try: + yield self.cur_mode + finally: + mode = self.cur_mode + self.cur_mode = old_mode + self.push_torch_function_mode(mode) class TorchFunctionModeStackVariable(VariableTracker): @@ -88,19 +193,20 @@ def reset(cls): def register_mutation(cls, tx: "InstructionTranslator"): if cls.stack_value_singleton not in tx.output.side_effects: var = cls( - source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack + source=Source(), + symbolic_stack=tx.symbolic_torch_function_state.mode_stack, ) tx.output.side_effects.track_mutable(cls.stack_value_singleton, var) tx.output.side_effects.mutation(var) @classmethod def register_device_context_insertion(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return else: cls.offset += 1 - tx.symbolic_torch_function_mode_stack.insert( + stack.insert( 0, TorchFunctionModeVariable( None, source=TorchFunctionModeStackSource(-cls.offset) @@ -109,7 +215,7 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): @classmethod def clear_default_device(cls, tx: "InstructionTranslator"): - stack = tx.symbolic_torch_function_mode_stack + stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @@ -123,24 +229,88 @@ def get_mode_index(cls, ind): return ind + cls.offset -class TorchFunctionModeVariable(ContextWrappingVariable): - def __init__(self, value, **kwargs): - super().__init__(value, **kwargs) - self.value = value - +class TorchFunctionModeVariable(GenericContextWrappingVariable): @staticmethod - def get_global_mangled_name(tx, val): - return get_safe_global_name( - tx, f"__torch_function_mode_{val.__class__.__name__}", val + def is_supported_torch_function_mode(ty): + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the funtion across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ ) + def __init__(self, value, source=None, **kwargs): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source + def reconstruct(self, codegen): - # We don't support locally created torch function modes yet + # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) - def _call_func(self, tx, values): - unimplemented("torch function mode context manager is not supported yet") + def module_name(self): + return self.value.__module__ + + def fn_name(self): + return type(self.value).__name__ + + def python_type(self): + return type(self.value) + + def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + return call_torch_function( + tx, + self, + build_torch_function_fn(tx, self.value, self.source), + fn, + types, + args, + kwargs, + ) + + def enter(self, tx): + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args): + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen): + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self): + return True + + def exit_on_graph_break(self): + return False def _get_all_args(args, kwargs): @@ -231,9 +401,13 @@ def build_torch_function_fn(tx: "InstructionTranslator", value, source): def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): - return tx.output.torch_function_enabled and any( + has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) + tf_state = tx.symbolic_torch_function_state + return (has_overridden_args and tf_state.torch_function_subclass_enabled) or ( + tf_state.torch_function_mode_enabled and tf_state.in_torch_function_mode() + ) def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): @@ -245,11 +419,20 @@ def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): _get_subclass_type, ) + types = TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]) + + if tx.symbolic_torch_function_state.in_torch_function_mode(): + res = tx.symbolic_torch_function_state.call_torch_function_mode( + tx, fn, types, args, kwargs + ) + if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): + return res + for arg in overloaded_args: res = arg.call_torch_function( tx, fn, - TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]), + types, args, kwargs, ) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a254bf97cf1c61..32057c838d74c8 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -2,12 +2,14 @@ import collections import contextlib +import dataclasses import enum import functools import inspect import itertools import random import sys +import threading import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING @@ -15,6 +17,7 @@ import torch._dynamo.config import torch.nn from torch._guards import TracingContext +from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type from .. import polyfills, variables from ..bytecode_transformation import create_call_function @@ -39,6 +42,7 @@ check_constant_args, get_custom_getattr, has_torch_function, + is_frozen_dataclass, is_namedtuple_cls, is_utils_checkpoint, is_wrapper_or_member_descriptor, @@ -79,11 +83,6 @@ def is_forbidden_context_manager(ctx): from _pytest.python_api import RaisesContext from _pytest.recwarn import WarningsChecker - # TODO mlazos: Temporary to get this stack to pass - # remove in subsequent PR - from torch.overrides import BaseTorchFunctionMode - - f_ctxs.append(BaseTorchFunctionMode) f_ctxs.append(RaisesContext) f_ctxs.append(WarningsChecker) except ImportError: @@ -113,9 +112,6 @@ def __init__(self, value, **kwargs) -> None: def as_python_constant(self): return self.value - def python_type(self): - return type(self.value) - def as_proxy(self): return self.value @@ -197,6 +193,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke else: return SourcelessBuilder.create(tx, func) elif isinstance(obj, classmethod): + if isinstance(obj.__func__, property): + return variables.UserFunctionVariable(obj.__func__.fget).call_function( + tx, [self], {} + ) return variables.UserMethodVariable(obj.__func__, self, source=source) elif isinstance(obj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static(dict, "fromkeys") @@ -379,8 +379,8 @@ def call_function( elif self.value is collections.deque and not kwargs: if len(args) == 0: items = [] - elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): - items = args[0].unpack_var_sequence(tx) + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) else: unimplemented("deque() with more than 1 arg not supported") return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) @@ -413,15 +413,25 @@ def call_function( and self.source and not is_forbidden_context_manager(self.value) ): - # import here to avoid an unfortunate circular dependency. + from torch.overrides import TorchFunctionMode + from .ctx_manager import GenericContextWrappingVariable + from .torch_function import TorchFunctionModeVariable + + if issubclass( + self.value, TorchFunctionMode + ) and TorchFunctionModeVariable.is_supported_torch_function_mode( + self.value + ): + var_cls = TorchFunctionModeVariable + else: + var_cls = GenericContextWrappingVariable cm_obj = tx.output.side_effects.track_object_new( - self.source, self.value, GenericContextWrappingVariable, {} + self.source, self.value, var_cls, {} ) cm_obj.call_method(tx, "__init__", args, kwargs) return cm_obj - elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one @@ -452,6 +462,40 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) + elif is_frozen_dataclass(self.value) and self.is_standard_new(): + from .builder import SourcelessBuilder + + fields = dataclasses.fields(self.value) + items = list(args) + items.extend([None] * (len(fields) - len(items))) + + default_kwargs = {} + for field, var_tracker in zip(fields, items): + if var_tracker is None: + if field.name in kwargs: + var_tracker = kwargs[field.name] + else: + if not field.init: + continue + + if field.default is not dataclasses.MISSING: + var_tracker = SourcelessBuilder.create(tx, field.default) + elif field.default_factory is not dataclasses.MISSING: + factory_fn = SourcelessBuilder.create( + tx, field.default_factory + ) + var_tracker = factory_fn.call_function(tx, [], {}) + else: + # if we are subclass, the constructor could possibly + # be missing args + continue + + default_kwargs[field.name] = var_tracker + kwargs.update(default_kwargs) + + var = tx.output.side_effects.track_object_new_from_user_defined_class(self) + var.call_method(tx, "__init__", args, kwargs) + return var elif ( self.is_standard_new() and SideEffects.cls_supports_mutation_side_effects(self.value) @@ -476,7 +520,10 @@ def call_function( user_cls_source=self.source, mutable_local=MutableLocal(), ) - elif self.value in self._in_graph_classes(): + elif ( + self.value in self._in_graph_classes() + or is_traceable_wrapper_subclass_type(self.value) + ): # torch.LongTensor cannot accept a list of FakeTensors. # So we stack the list of FakeTensors instead. if ( @@ -674,7 +721,7 @@ def call_method( if method is object.__init__: return ConstantVariable.create(None) - if is_standard_setattr(method): + if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) # [NOTE] OrderedDict, dict subtypes must always have source @@ -712,7 +759,7 @@ def call_method( assert not (args or kwargs) items = [] keys = self.call_method(tx, "keys", [], {}) - for key in keys.unpack_var_sequence(tx): + for key in keys.force_unpack_var_sequence(tx): items.append( TupleVariable( [key, self.odict_getitem(tx, key)], @@ -772,7 +819,7 @@ def method_setattr_standard(self, tx: "InstructionTranslator", name, value): def needs_slow_setattr(self): return not is_standard_setattr( inspect.getattr_static(self.value, "__setattr__", None) - ) + ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx): if ( @@ -899,6 +946,18 @@ def _check_for_getattribute(self): def _check_for_getattr(self): return get_custom_getattr(self.value) + def _is_c_defined_property(self, subobj): + if not isinstance(subobj, property): + return False + + # pybind def_readwrite is implemented via PyCFunction. At the python level, it is visible as a property whose + # fget is an instancemethod wrapper - https://docs.python.org/3/c-api/method.html#c.PyInstanceMethod_Check + + # If we have a PyCFunction, we make an assumption that there is no side effect. + return isinstance( + subobj.fget, types.BuiltinFunctionType + ) or torch._C._dynamo.utils.is_instancemethod(subobj.fget) + def _getattr_static(self, name): subobj = inspect.getattr_static(self.value, name, NO_SUCH_SUBOBJ) import _collections @@ -913,12 +972,7 @@ def _getattr_static(self, name): or ( inspect.ismemberdescriptor(subobj) and name in self.value.__slots__ ) # handle memberdecriptor and slots - or ( - isinstance(subobj, property) - and isinstance( - subobj.fget, types.BuiltinFunctionType - ) # property with C-defined fget - ) + or self._is_c_defined_property(subobj) ): # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't # want to call getattr because it can be user-overridden. @@ -1168,6 +1222,54 @@ def odict_getitem(self, tx: "InstructionTranslator", key): )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant())) +class FrozenDataClassVariable(UserDefinedObjectVariable): + @staticmethod + def create(tx, value, source): + from dataclasses import fields + + assert is_frozen_dataclass(value) + + from .builder import VariableBuilder + + field_map = {} + for field in fields(value): + if hasattr(value, field.name): + field_map[field.name] = VariableBuilder( + tx, AttrSource(source, field.name) + )(getattr(value, field.name)) + + return FrozenDataClassVariable(value, fields=field_map, source=source) + + def __init__(self, value, fields=None, **kwargs) -> None: + super().__init__(value, **kwargs) + if fields is None: + fields = {} + self.fields = fields + + def as_proxy(self): + from dataclasses import fields + + args = [] + kwargs = {} + for field in fields(self.value): + proxy = self.fields[field.name].as_proxy() + if hasattr(field, "kw_only") and field.kw_only: + kwargs[field.name] = proxy + else: + args.append(proxy) + + return self.python_type()(*args, **kwargs) + + # NB: This is called during __init__ for a frozen dataclass + # use this to accumulate the most up-to-date field values + def method_setattr_standard(self, tx: "InstructionTranslator", name, value): + self.fields[name.as_python_constant()] = value + return super().method_setattr_standard(tx, name, value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.value_type.__name__})" + + class SourcelessGraphModuleVariable(UserDefinedObjectVariable): def __init__( self, diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 4e9f7433f8be2c..88106676c7a7bc 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -65,9 +65,9 @@ def capture_pre_autograd_graph_warning(): log.warning("| !!! WARNING !!! |") log.warning("+============================+") log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.") - log.warning("Please switch to use torch.export._trace._export_for_training instead.") + log.warning("Please switch to use torch.export.export_for_training instead.") if config.is_fbcode(): - log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export._trace._export_for_training.") # noqa: B950 + log.warning("For unittest, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.") # noqa: B950 @compatibility(is_backward_compatible=False) @@ -128,9 +128,9 @@ def capture_pre_autograd_graph( if capture_pre_autograd_graph_using_training_ir(): @lru_cache def print_export_warning(): - log.warning("Using torch.export._trace._export_for_training(...,strict=True)") + log.warning("Using torch.export.export_for_training(...,strict=True)") print_export_warning() - module = torch.export._trace._export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() + module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module() else: log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"}) @@ -184,20 +184,20 @@ def print_export_warning(): range_constraints=range_constraints, ) - error_message = \ - """ - Calling train() or eval() is not supported for exported models. - Alternatively, you may override these methods to do custom user behavior as follows: + error_message = \ + """ + Calling train() or eval() is not supported for exported models. + Alternatively, you may override these methods to do custom user behavior as follows: - def _my_train(self, mode: bool = True): - ... + def _my_train(self, mode: bool = True): + ... - def _my_eval(self): - ... + def _my_eval(self): + ... - model.train = types.MethodType(_my_train, model) - model.eval = types.MethodType(_my_eval, model) - """ + model.train = types.MethodType(_my_train, model) + model.eval = types.MethodType(_my_eval, model) + """ def _train(self, mode: bool = True): raise NotImplementedError(error_message) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 8b5ff29b6c9e02..ef15e5fea9e973 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -24,10 +24,10 @@ from torch.export.dynamic_shapes import ( _check_dynamic_shapes, _combine_args, + _DimHint, _process_dynamic_shapes, _transform_shapes_for_default_dynamic, _tree_map_with_path, - DIM, ) from torch.export.graph_signature import CustomObjArgument from torch.fx.experimental import _config as config @@ -136,10 +136,10 @@ def make_fake_inputs( combined_args = _combine_args(nn_module, args, kwargs) _check_dynamic_shapes(combined_args, dynamic_shapes) - _dynamic_shapes = _transform_shapes_for_default_dynamic( + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( combined_args, dynamic_shapes ) - constraints = _process_dynamic_shapes(combined_args, _dynamic_shapes) + constraints = _process_dynamic_shapes(combined_args, transformed_dynamic_shapes) t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict) for constraint in constraints: t_constraints[constraint.t_id][constraint.dim] = constraint @@ -216,7 +216,14 @@ def make_fake_inputs( phantom_symbols=list(phantom_symbols.values()), warn_only=False, ) - return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature + return ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + transformed_dynamic_shapes, + ) def _flatten_dynamic_shapes( @@ -351,7 +358,7 @@ def make_constraints( # we want the symbol, not its replacement, which could be an expression. Maybe # there's a better way to do this, e.g., by (re)computing value ranges for expressions? dim = shape_spec[i] if shape_spec else None - if dim is None or isinstance(dim, DIM): + if dim is None or isinstance(dim, _DimHint): range_constraints[d.node.expr] = shape_env.var_to_range[ d.node._expr ] @@ -499,7 +506,11 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} - if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc: + if ( + not torch.compiler.is_dynamo_compiling() + and log.isEnabledFor(logging.DEBUG) + and config.extended_debug_current_loc + ): frame = _find_user_code_frame() if frame is not None: log.debug( diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index b1491ca5d47946..083cd69a970df6 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -195,7 +195,7 @@ def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: self.node_replacements[node] = tensor - def run(self): + def run(self): # type: ignore[override] env = {} for n in self.module.graph.find_nodes(op="placeholder"): env[n] = self.unknown_value diff --git a/torch/_export/serde/dynamic_shapes.py b/torch/_export/serde/dynamic_shapes.py new file mode 100644 index 00000000000000..f24822d9b07d68 --- /dev/null +++ b/torch/_export/serde/dynamic_shapes.py @@ -0,0 +1,321 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch._dynamo.exc import UserError, UserErrorType +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _DerivedDim, + _Dim, + _DimHint, + _tree_map_with_path, + Dim, +) +from torch.utils._pytree import tree_map + +from .serialize import _dataclass_to_dict + + +@dataclasses.dataclass +class RootDim: + """ + This represents a _Dim object. + """ + + min: int + max: Union[int, None] + derived: List[str] + + +@dataclasses.dataclass +class DynamicShapesSpec: + """ + This stores a dynamic_shapes spec for de/serialization. + """ + + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] + dims: Dict[str, RootDim] + + +def _postprocess_serialized_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + dims: Dict[str, Dict[str, Union[int, List[str], None]]], + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Sorts dims and dumps to dictionary format. + """ + from torch.utils._sympy.numbers import int_oo + + dims = { + k: RootDim( + min=v["min"], # type: ignore[arg-type] + max=None if v["max"] is int_oo else v["max"], # type: ignore[arg-type] + derived=sorted(v["derived"]), # type: ignore[arg-type] + ) + for k, v in sorted(dims.items()) + } + spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims) + if to_dict: + return _dataclass_to_dict(spec) + else: + return spec + + +def _dump_dynamic_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + to_dict: Optional[bool] = False, +) -> Union[DynamicShapesSpec, Dict[str, Any]]: + """ + Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec. + Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims". + Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones). + + dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export(): + - Each tensor input is represented with a list of values, non-tensor inputs with None. + - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings. + - static dimensions are represented with ints. + + dims: A dictionary mapping each symbol name to the min/max range and derived dim names. + + For example: + ``` + dx = Dim("dx", min=4, max=16) + dy = dx + 1 + + inputs = ( + [ + torch.randn(4, 4), + torch.randn(5, 4), + ], + torch.randn(4), + torch.randn(4, 4), + "hello", + ) + dynamic_shapes = { + "a": [ + (dx, 4), + (dy, 4), + ], + "b": (Dim.STATIC,), + "c": None, + "d": None, + } + out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True) + ``` + would generate the following output: + ``` + { + 'dynamic_shapes': ( + [ + ['dx', 4], + ['dx + 1', 4], + ], + ['_DimHint.STATIC'], + ['_DimHint.STATIC', '_DimHint.STATIC'], + None, + ), + 'dims': { + 'dx': { + 'min': 4, + 'max': 16, + 'derived': ['dx + 1'], + }, + }, + } + ``` + """ + dims: Dict[str, Dict[str, Any]] = {} + + def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def] + """ + Helps standardize the dynamic_shapes tree structure we serialize, + returning lists for each tensor shape, handling tensor-level Nones. + """ + if not isinstance(tensor, torch.Tensor): + return None + if shape is None: + return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined] + + out = [] + if isinstance(shape, dict): + for i, s in enumerate(tensor.shape): + out.append(s if shape.get(i) is None else shape.get(i)) + else: + assert isinstance(shape, (tuple, list)) + for i, s in enumerate(tensor.shape): + out.append(s if shape[i] is None else shape[i]) + return out + + def _track_dim_from_dims( + val: Union[None, int, _DimHint, _Dim] + ) -> Union[None, int, str]: + """ + Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec. + """ + if val is None or isinstance(val, int): # non-tensor input or static + return val + if isinstance(val, _DimHint): # store enum as string + return val.__class__.__name__ + "." + val.name + + assert isinstance(val, _Dim) + + # track root dim + root = val.root if isinstance(val, _DerivedDim) else val # type: ignore[attr-defined] + if root.__name__ not in dims: + dims[root.__name__] = { + "min": root.min, + "max": root.max, + "derived": set(), + } + + # track derived dims + if isinstance(val, _DerivedDim): + dims[root.__name__]["derived"].add(val.__name__) + + return val.__name__ + + if dynamic_shapes is None: + return {"dynamic_shapes": None, "dims": {}} + + # convert to tuple of specs, for each arg/kwarg + kwargs = kwargs or {} + if isinstance(dynamic_shapes, dict): + dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment] + dynamic_shapes = tuple(dynamic_shapes) + combined_args = tuple(args) + tuple(kwargs.values()) + + # run same check when we're processing shapes for export - is this too lazy? + _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes) # type: ignore[arg-type] + + tree_shapes = _tree_map_with_path( + _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs" + ) + serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes) + return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict) + + +def _load_dynamic_shapes( + spec: Union[DynamicShapesSpec, Dict[str, Any]], + from_dict: Optional[bool] = False, +) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: + """ + Utility function for dynamic shapes serialization. + Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export(). + """ + import sympy + + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + if from_dict: + if not isinstance(spec, dict): + raise UserError( + UserErrorType.INVALID_INPUT, + f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}", + ) + if sorted(spec.keys()) != ["dims", "dynamic_shapes"]: + raise UserError( + UserErrorType.INVALID_INPUT, + "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, " + f"instead found {spec.keys()}", + ) + dims = {} + for k, v in spec["dims"].items(): + if not isinstance(k, str): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}", + ) + if sorted(v.keys()) != ["derived", "max", "min"]: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, " + f"instead found {v.keys()}", + ) + if not isinstance(v["min"], int): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}", + ) + if not isinstance(v["max"], int) or v["max"] is None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}", + ) + if not isinstance(v["derived"], list) or any( + not isinstance(d, str) for d in v["derived"] + ): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, " + f"got {k}: {v['derived']}", + ) + dims[k] = RootDim(**v) + dynamic_shapes = spec["dynamic_shapes"] + else: + if not isinstance(spec, DynamicShapesSpec): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}", + ) + dims = spec.dims + dynamic_shapes = spec.dynamic_shapes + + if dynamic_shapes is None: + return None + + dim_cache = {} + for name, info in dims.items(): + symbol = sympy.sympify(name) + if not isinstance(symbol, sympy.Symbol): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected `spec['dims']` keys to be symbols, got {name}", + ) + dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim + for _expr in info.derived: + expr = sympy.sympify(_expr) + if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions in to have {name} as the only free symbol, got {expr}", + ) + if not _is_supported_equivalence(expr): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected derived expressions to be linear expressions, got {expr}", + ) + modulus, remainder = sympy.polys.polytools.div(expr, symbol) + ddim = dim_cache[name] + if modulus != 1: + ddim = int(modulus) * ddim + if remainder != 0: + ddim = ddim + int(remainder) + dim_cache[_expr] = ddim # cache derived dims + + def deserialize_shape( + val: Union[None, int, str] + ) -> Union[None, int, _Dim, _DimHint]: + if val is None or isinstance(val, int): + return val + elif val == "_DimHint.AUTO": + return _DimHint.AUTO + elif val == "_DimHint.STATIC": + return _DimHint.STATIC + if not isinstance(val, str): + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, " + f" or derived expressions, got {val}", + ) + if val not in dim_cache: + raise UserError( + UserErrorType.INVALID_INPUT, + "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, " + f"got {val} which is not in {dims.keys()}", + ) + return dim_cache[val] + + return tree_map(deserialize_shape, dynamic_shapes) diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index c12997e049c1de..ce102b39367ad0 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -8,7 +8,7 @@ from torch._export.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (7, 2) +SCHEMA_VERSION = (7, 3) TREESPEC_VERSION = 1 @@ -378,3 +378,4 @@ class ExportedProgram: range_constraints: Dict[str, RangeConstraint] schema_version: SchemaVersion verifiers: List[str] = field(default_factory=list) + torch_version: str = "<=2.4" diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index f8eda2b5c3b02c..25a9a295ad0b97 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<8c1eb1b5b62b21e3725448a11d3a807cfe006e7558870a2c1ccc455d3def71c4>> +# checksum<<923abf371a1f8802cacb037d409d28273867777a98f6542fba28616c2b92b639>> Argument: kind: union fields: @@ -105,6 +105,9 @@ ExportedProgram: verifiers: type: List[str] default: '[]' + torch_version: + type: str + default: <=2.4 GradientToParameterSpec: kind: struct fields: @@ -430,5 +433,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 7 -- 2 +- 3 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index f77d081d5aa85c..33a08032479e33 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -7,12 +7,15 @@ import inspect import io import json +import keyword import logging import math import operator +import re import typing import traceback +from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -578,23 +581,21 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]: ret["stack_trace"] = stack_trace if nn_module_stack := node.meta.get("nn_module_stack"): - keys, paths, tys = [], [], [] - for k, v in nn_module_stack.items(): - keys.append(k) - assert isinstance(v, tuple) and len(v) == 2 - path, ty = v + + def export_nn_module_stack(val): + assert isinstance(val, tuple) and len(val) == 2 + path, ty = val + assert isinstance(path, str) assert isinstance(ty, str) - paths.append(path) - tys.append(ty) - nn_module_stack_dict = { - "key_list": keys, - "path_list": paths, - "type_list": tys - } + return path + "," + ty - ret["nn_module_stack"] = json.dumps(nn_module_stack_dict) + # Serialize to "key,orig_path,type_str" + nn_module_list = [ + f"{k},{export_nn_module_stack(v)}" for k, v in nn_module_stack.items() + ] + ret["nn_module_stack"] = ST_DELIMITER.join(nn_module_list) if source_fn_st := node.meta.get("source_fn_stack"): source_fn_list = [ @@ -1395,6 +1396,7 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: minor=SCHEMA_VERSION[1], ), verifiers=[v.dialect for v in exported_program.verifiers], + torch_version=torch.__version__, ) # Test canonical form is well defined. @@ -1946,13 +1948,19 @@ def deserialize_inputs(self, target, serialized_node: Node): for input in serialized_node.inputs } args = [] - kwargs = {} + kwargs: OrderedDict[str, Any] = OrderedDict() for schema_arg in schema_args: is_positional = ( not schema_arg.has_default_value() and not schema_arg.kwarg_only ) if is_positional: args.append(actual_args[schema_arg.name]) + elif keyword.iskeyword(schema_arg.name): + assert not schema_arg.kwarg_only + if len(kwargs) > 0: + kwargs = OrderedDict() + args.extend(list(kwargs.values())) + args.append(actual_args[schema_arg.name]) else: if schema_arg.name in actual_args: kwargs[schema_arg.name] = actual_args[schema_arg.name] @@ -2189,13 +2197,25 @@ def deserialize_meta_func(serialized_target: str): if nn_module_stack_str := metadata.get("nn_module_stack"): # Originally serialized to "key,orig_path,type_str" - nn_module_stack_dict = json.loads(nn_module_stack_str) - nn_module_stack = { - key: (path, ty) - for key, path, ty in zip( - nn_module_stack_dict["key_list"], - nn_module_stack_dict["path_list"], - nn_module_stack_dict["type_list"])} + def import_nn_module_stack(key, path, ty): + return key, (path, ty) + + # Helper function that splits strings by commas except for those + # encapsulated by parens, which are valid traces. + # TODO: Currently this is needed due to indexing Sequential + # layers introducing names in the form "layer.slice(1, None, None)". + # If that naming is improved, this fancier splitting can probably be + # reverted to a simple split by comma. + def metadata_split(metadata): + # Remove the parentheses and commas inside them + metadata = re.sub(r'\(.*?\)', '', metadata) + # Split the string by comma, except for those inside parentheses + return re.split(r'(? "ConstantAttrMap": + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. This is intended to only be used + # in run_decompositions where we still have access to original EP. + from torch._export.passes.lift_constants_pass import ConstantAttrMap + + constant_attrs = ConstantAttrMap() + non_persistent_buffers = { + spec.target + for spec in graph_signature.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + for name, value in constants.items(): + if name in non_persistent_buffers: + continue + # recursive getattr + _mod = mod + *atoms, attr = name.split(".") + for atom in atoms: + _mod = getattr(_mod, atom) + # remove as buffer, reassign as constant/non-persistent buffer + _mod._buffers.pop(attr, None) + setattr(_mod, attr, value) + constant_attrs.add(value, name) + return constant_attrs + + +def _overwrite_signature_for_non_persistent_buffers( + old_sig: "ExportGraphSignature", new_sig: "ExportGraphSignature" +): + # overwrite signature for non-persistent buffers + non_persistent_buffers = { + spec.target + for spec in old_sig.input_specs + if spec.kind == InputKind.BUFFER and not spec.persistent + } + + for spec in new_sig.input_specs: + if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: + spec.persistent = False + return new_sig + + +def _collect_param_buffer_metadata(mod: torch.fx.GraphModule) -> Dict[str, Any]: + """ + Param/buffer metadata needs to be saved before lowering to aten IR + because aten IR lifts them, as a result, automatic preservation doesn't work. + This is intended to be called on the strict mode tracing right before lowering to + aten IR OR run_decomposition pass. + """ + params_buffers_to_node_meta = {} + + def _getattr(model: torch.fx.GraphModule, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + for node in mod.graph.nodes: + target = node.target + meta = node.meta + if node.op == "call_module": + submodule = _getattr(mod, target) + if isinstance(submodule, torch.nn.Module): + for name, _ in submodule.named_parameters( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + for name, _ in submodule.named_buffers( + recurse=True, remove_duplicate=False + ): + params_buffers_to_node_meta[target + "." + name] = meta + + if node.op == "get_attr": + submodule = _getattr(mod, target) + if not isinstance(submodule, torch.fx.GraphModule): + params_buffers_to_node_meta[target] = meta + + # If the call_function uses param as input, we also need to update params' meta + # with this call_function node's meta. + # This is basically the same flow as torch.fx.traceback.preserve_meta() + if node.op == "call_function" and not isinstance( + node.target, torch._ops.HigherOrderOperator + ): + for arg in node._input_nodes: + if arg.op == "get_attr": + for entry in torch.fx.proxy._COPY_META_FIELDS: + # the custom field should not be copied + if entry == "custom": + continue + if entry in meta: + params_buffers_to_node_meta[arg.target][entry] = meta[entry] + + return params_buffers_to_node_meta + + +def _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta: Dict[str, Any], + gm: torch.fx.GraphModule, + new_sig: "ExportGraphSignature", +) -> None: + """ + Given that we collected param'buffer metadata before, we put them back in + newly traced graph module + """ + # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes + for metadata in params_buffers_to_node_meta.values(): + metadata.pop("nn_module_stack", None) + metadata.pop("stack_trace", None) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node.target in new_sig.inputs_to_parameters: + param_name = new_sig.inputs_to_parameters[node.target] + if param_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[param_name].items(): + node.meta[k] = v + if node.target in new_sig.inputs_to_buffers: + buffer_name = new_sig.inputs_to_buffers[node.target] + if buffer_name in params_buffers_to_node_meta: + for k, v in params_buffers_to_node_meta[buffer_name].items(): + node.meta[k] = v + + +def _get_shape_env_from_gm(gm: torch.fx.GraphModule): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + + fake_mode = _detect_fake_mode_from_gm(gm) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _rename_without_collisions( + name_map: Dict[str, str], + orig_name: str, + name: str, + is_placeholder: bool = False, +): + """ + Renames nodes to avoid name collisions, with suffixing. + name_map: map from original name to new name + orig_name: mapping key + name: candidate name (potentially suffixed, e.g. mul_2) + is_placeholder: if the node is a placeholder, avoid detecting suffix + """ + if name in name_map.values(): + # non-placeholder nodes may be suffixed with the count + # instead of adding another suffix, we will try to increment it + match = re.match(r"(.*)_(\d+)", name) + if match and not is_placeholder: + name, n = match.group(1), int(match.group(2)) + else: + n = 0 + while (dup_name := f"{name}_{n + 1}") in name_map.values(): + n += 1 + name_map[orig_name] = dup_name + else: + name_map[orig_name] = name + return name_map[orig_name] + + def _check_input_constraints_for_graph( input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints ): @@ -227,7 +405,7 @@ def default_flatten_fn_with_keys(obj: Any) -> Tuple[List[Any], Context]: ) -def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: +def is_param(program: "ExportedProgram", node: torch.fx.Node) -> bool: """ Checks if the given node is a parameter within the exported program """ @@ -236,7 +414,7 @@ def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: def get_param( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.nn.Parameter]: """ @@ -251,7 +429,7 @@ def get_param( return None -def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: +def is_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: """ Checks if the given node is a buffer within the exported program """ @@ -260,7 +438,7 @@ def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: def get_buffer( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.Tensor]: """ @@ -279,7 +457,7 @@ def get_buffer( def is_lifted_tensor_constant( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> bool: """ @@ -290,7 +468,7 @@ def is_lifted_tensor_constant( def get_lifted_tensor_constant( - program: ExportedProgram, + program: "ExportedProgram", node: torch.fx.Node, ) -> Optional[torch.Tensor]: """ @@ -484,9 +662,51 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): return sig.bind(*fake_args, **fake_kwargs).arguments +def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + """ + Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, + and handle collisions with non-placeholders by count suffixing. + Different HOO subgraph types have different input schemas, so we first enumerate them + and gather the top-level named placeholder nodes. + """ + # gather all HOO subgraphs and their top-level named placeholder nodes + subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] + for node in gm.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + # HOO subgraphs have varying input schemas, so we enumerate them there + if node.target._name == "cond": + _, true_graph, false_graph, cond_args = node._args + subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) + subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) + elif node.target._name == "wrap_with_set_grad_enabled": + subgraph, phs = node._args[1], node._args[2:] + subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) + elif node.target._name == "map_impl": + body_graph, array, args = node._args + subgraph_ph_tuples.append( + (getattr(gm, body_graph.target), array + args) + ) + + # propagate names + for subgraph, hoo_phs in subgraph_ph_tuples: + name_map: Dict[str, str] = {} + for i, node in enumerate(subgraph.graph.nodes): + if i < len(hoo_phs): # placeholder, retain name + name_map[node.name] = hoo_phs[i].name + node.name = node.target = hoo_phs[i].name + else: # non-placeholder, check for collisions + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # recurse and recompile + _name_hoo_subgraph_placeholders(subgraph) + subgraph.recompile() + + def placeholder_naming_pass( gm: torch.fx.GraphModule, - export_graph_signature: torch.export.ExportGraphSignature, + export_graph_signature: "ExportGraphSignature", mod: torch.nn.Module, fake_args, fake_kwargs, @@ -514,6 +734,8 @@ def placeholder_naming_pass( def _strip_name(x): if x.startswith("L__self___"): x = x[len("L__self___") :] + elif x.startswith("self_"): + x = x[len("self_") :] x = re.sub(r"[^a-zA-Z0-9]", "_", x) return x @@ -652,7 +874,6 @@ def _detect_fake_mode_from_gm( Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. If no fake mode is found, we return None for fake_mode. """ - from torch._guards import detect_fake_mode fake_inps: List[torch.Tensor] = [] fake_vals: List[torch.Tensor] = [] diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index b272621f8c94b8..68c5bcaae39af6 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -153,6 +153,7 @@ def check_additional(self, gm: GraphModule) -> None: @final def check(self, ep: "ExportedProgram") -> None: self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) _verify_exported_program_signature(ep) @final @@ -271,6 +272,25 @@ class TrainingIRVerifier(Verifier): dialect = "TRAINING" +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = { + node.name for node in exported_program.graph.nodes + } + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." + ) + + def _verify_exported_program_signature(exported_program) -> None: # Check ExportedProgram signature matches gs = exported_program.graph_signature diff --git a/torch/_export/wrappers.py b/torch/_export/wrappers.py index 5aee2f16a0ddcf..d57ff46de41c8f 100644 --- a/torch/_export/wrappers.py +++ b/torch/_export/wrappers.py @@ -16,6 +16,9 @@ class ExportTracepoint(HigherOrderOperator): def __init__(self): super().__init__("_export_tracepoint") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + _export_tracepoint = ExportTracepoint() diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index d280f67d2fd516..17b619519c158d 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -256,13 +256,13 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGra # [Note: AOTAutogradCache and FXGraphCache Guard interactions] # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. - # he invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly # the same as the ones it passes to inductor, for both the forward and backward passes. # (This does not mean that the tensor values passed in are the same: only that their symints are). # That is, AOTAutograd and Inductor never create new guards based on symints with different sources # than those passed to it by inductor. result = FxGraphCache._lookup_graph( - self.fx_graph_cache_key, example_inputs, local=True, remote_cache=False + self.fx_graph_cache_key, example_inputs, local=True, remote_cache=None ) if result is None: log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_key) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 820c7d288cf2c6..51a0aeb24ad496 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -9,6 +9,7 @@ """ import collections +import contextlib import logging from functools import wraps from typing import Callable, DefaultDict, Dict, List, Optional @@ -162,15 +163,18 @@ def inner(*flat_args): # It doesn't matter if we run this under predispatch or not because it is # only for figuring out metadata mode = FunctionalTensorMode(_allow_token_discovery=True) - with disable_above, mode: + suppress_pending = contextlib.nullcontext() + fake_mode = detect_fake_mode() + if fake_mode and (shape_env := fake_mode.shape_env): + suppress_pending = shape_env.ignore_fresh_unbacked_symbols() + with disable_above, mode, suppress_pending: # precondition: The passed in function already handles unflattening inputs + flattening outputs flat_f_args = pytree.tree_map(_to_fun, flat_args) flat_f_outs = f(*flat_f_args) # We didn't do any tracing, so we don't need to process the # unbacked symbols, they will just disappear into the ether. # Also, prevent memoization from applying. - if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env): - shape_env.pending_fresh_unbacked_symbols.clear() + if fake_mode: fake_mode.epoch += 1 fake_mode.reset_nt_tensor_id_counter() diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 2013eca1f6fb86..a12d42db7475e0 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -184,7 +184,6 @@ def _map_assigned_buffer_to_proxy(_mod, name, buffer): # As long as we opted to remove input mutations, then # there should be *NO* mutating ops in the graph at this point. copy_count = assert_functional_graph(fw_module.graph) - fw_module.graph.eliminate_dead_code() fw_module.recompile() diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 48e4ba84cf3c02..71862997ae0717 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -405,8 +405,8 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, ] - if hasattr(torch.ops.fsdp, "set_"): - allowed_mutation_ops.append(torch.ops.fsdp.set_.default) + if hasattr(torch.ops.fsdp, "copy_"): + allowed_mutation_ops.append(torch.ops.fsdp.copy_.default) placeholders = set() mutation_count = 0 diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index f01e83d37b1c50..5dc236f314b079 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -512,7 +512,11 @@ def aot_dispatch_autograd( == MutationType.MUTATED_IN_GRAPH and fw_metadata.input_info[i].mutates_storage_metadata ) - if bw_out is None and not metadata_mutation_in_graph: + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: _indices_of_inps_to_detach.append(i) if aot_config.enable_log: diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 270c1895f6fd54..cb501e2c924213 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -719,6 +719,7 @@ def __init__(self) -> None: def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs): saved_values = None args_tensor_mask = fwd_kwargs["args_tensor_mask"] + non_differentiable_idx = fwd_kwargs["non_differentiable_idx"] length_of_tensor_args = sum(args_tensor_mask) # Filter out the original tensor args from fwd_args, # lifted freevars should not be args of ApplyTemplate.apply @@ -730,6 +731,15 @@ class ApplyTemplate(torch.autograd.Function): def forward(ctx, *args): nonlocal saved_values output, saved_values = fwd(None, *fwd_args) + + # If users call ctx.mark_non_differentiable() in the original fwd function. + if len(non_differentiable_idx) > 0: + non_differentiable_output = [] + for i, x in enumerate(output): + if i in non_differentiable_idx: + non_differentiable_output.append(x) + ctx.mark_non_differentiable(*non_differentiable_output) + return output @staticmethod diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 77aeeeb2b7f1f1..3bf61f1af3bf3c 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -5,6 +5,7 @@ import torch import torch.fx as fx +from torch.multiprocessing.reductions import StorageWeakRef from torch.utils import _pytree as pytree from torch.utils._pytree import tree_flatten @@ -50,6 +51,37 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph): ) compute_mutation_region_ids(fx_g) # type: ignore[arg-type] + + # Make a set of separate storages returned from the output, which will be preserved + # when pruning. This prevents us from deduplicating returned tensors which have + # experienced identical operations, but are separate data structures in eager mode. + output_node: fx.Node = list(fx_g.nodes)[-1] + assert output_node.op == "output" + + def checkable_node(node: fx.Node) -> bool: + """We can evaluate only nodes that represent tensors with defined storage.""" + if "val" not in node.meta or not isinstance(node.meta["val"], torch.Tensor): + return False + + try: + node.meta["val"].untyped_storage() + except NotImplementedError: + return False + + return True + + output_storages = { + StorageWeakRef(n.meta["val"].untyped_storage()) + for n in output_node.all_input_nodes + if checkable_node(n) + } + nodes_that_alias_outputs = { + n + for n in fx_g.nodes + if checkable_node(n) + and StorageWeakRef(n.meta["val"].untyped_storage()) in output_storages + } + for n in fx_g.nodes: # The placeholder, output, and get_attr nodes are copied to the new graph without change # do not CSE away random operations @@ -58,6 +90,11 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph): or n.op == "output" or n.op == "get_attr" or get_aten_target(n) in rand_ops + # aten.empty is non-deterministic, so don't CSE it. + # Also, aten.empty is almost always fusible into its consumer, + # so it's not worth CSEing. + or get_aten_target(n) is aten.empty + or n in nodes_that_alias_outputs ): new_node = new_graph.node_copy(n, lambda x: env[x]) env[n] = new_node diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 86c63be17fc9d0..7ecc4a2d7449ff 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -12,7 +12,7 @@ def functional_call( module: "torch.nn.Module", parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 8716f792a4294f..81e2f297f6fd2c 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -784,6 +784,26 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] ): node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + if node.meta.get("has_backward_hook", False) and not any( + must_recompute(user) for user in node.users + ): + # If node is AC region output and has a backward hook on it, we intentionally choose to save it. + # This is to work around circular dependencies in Traceable FSDP2+AC. + # Example: + # ``` + # out = fully_shard(utils.checkpoint(module))(x) + # norm_out = layer_norm(out) + # ``` + # Here there is a circular dependency: + # 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`. + # 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights) + # in order to be recomputed. + # 3. `out`'s backward hook, as is the case for all eager backward hooks, depends on `out_grad` + # -> circular dependency with (1)! + # + # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` + # in forward graph outputs. With this, we can break the above circular dependency. + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -804,14 +824,45 @@ def solve_min_cut( if node.op == "call_function" and hasattr(node.target, "_overloadpacket") } ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} - print("Ops banned from rematerialization: ", ops_ignored) + print("Ops banned from re-materialization: ", ops_ignored) print() + def can_fuse_into_auto_functionalized(a, b): + if b.target != torch.ops.higher_order.auto_functionalized: + return False + mutable_op = b.args[0] + ( + mutable_arg_names, + _, + ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op) + for name in mutable_arg_names: + arg = b.kwargs[name] + if a is arg: + return True + if isinstance(arg, list): + if a in arg: + return True + return False + + def can_fuse_into_triton_kernel_wrapper_functional(a, b): + if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional: + return False + mutable_arg_names = b.kwargs["tensors_to_clone"] + for name in mutable_arg_names: + arg = b.kwargs["kwargs"][name] + if a is arg: + return True + return False + def is_fusible(a, b): # We can perform "memory fusion" into a cat, but cat cannot be a # producer to a fusion if get_aten_target(b) == aten.cat: return True + if can_fuse_into_auto_functionalized(a, b): + return True + if can_fuse_into_triton_kernel_wrapper_functional(a, b): + return True return op_types.is_fusible(a) and op_types.is_fusible(b) try: @@ -1241,6 +1292,7 @@ def get_default_op_list() -> OpTypes: aten.expand, aten.as_strided, aten.permute, + aten.select, ] view_ops = recomputable_view_ops default_recomputable_ops += [ @@ -1273,6 +1325,8 @@ def get_default_op_list() -> OpTypes: aten.full, aten.as_strided, aten.zeros, + aten.empty, + aten.empty_like, aten.argmax, aten.maximum, prims.iota, diff --git a/torch/_guards.py b/torch/_guards.py index ccc07735815bf4..012f26c5bb3ba3 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -26,6 +26,7 @@ TypeVar, ) +from torch._C._dynamo.eval_frame import set_context_frame # noqa: F401 from torch.utils import _pytree as pytree from torch.utils._traceback import CapturedTraceback from torch.utils.weak import WeakTensorKeyDictionary @@ -781,6 +782,15 @@ def compile_context(context: Optional[CompileContext]): try: yield context finally: + if context is not None: + if context.compile_id is not None: + set_context_frame( + ( + context.compile_id.frame_id, + context.compile_id.frame_compile_id, + context.attempt, + ) + ) _TLS.compile_context = old_context diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 5bfca7aa679583..72800cae7fc98c 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -3,6 +3,7 @@ flex_attention, flex_attention_backward, ) +from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.while_loop import while_loop @@ -11,4 +12,5 @@ "while_loop", "flex_attention", "flex_attention_backward", + "hints_wrapper", ] diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py index ef57af9c0d63b4..d58d6b26bd33f7 100644 --- a/torch/_higher_order_ops/associative_scan.py +++ b/torch/_higher_order_ops/associative_scan.py @@ -8,13 +8,14 @@ import torch._subclasses.functional_tensor import torch.utils._pytree as pytree from torch._C import DispatchKey -from torch._C._functorch import _add_batch_dim, get_unwrapped, maybe_get_bdim from torch._higher_order_ops.utils import ( + _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, reenter_make_fx, unique_graph_id, ) +from torch._inductor.utils import is_pointwise_use from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( @@ -37,12 +38,43 @@ def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves): return combined_leaves +def _interleave(a, b, dim): + # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors + if b_trunc := (a.shape[dim] == b.shape[dim] + 1): + pad = ( + [0] * ((b.ndim - dim - 1) * 2 + 1) + + [1] + + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2)) + ) + b = torch.nn.functional.pad(b, pad) + + stacked = torch.stack([a, b], dim=dim + 1) + interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1) + if b_trunc: + # TODO: find torch alternative for slice_along dim for torch.jit.script to work + interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1) + return interleaved + + +def safe_map(f, *args): + args = list(map(list, args)) + n = len(args[0]) + for arg in args[1:]: + if len(arg) != n: + raise ValueError("length mismatch: {list(map(len, args))}") + + def nf(a): + return f(*a) + + return list(map(nf, zip(*args))) + + class AssociativeScanOp(HigherOrderOperator): def __init__(self): super().__init__("associative_scan") - def __call__(self, combine_fn, input, dim): - return super().__call__(combine_fn, input, dim) + def __call__(self, combine_fn, xs, dim): + return super().__call__(combine_fn, xs, dim) associative_scan_op = AssociativeScanOp() @@ -50,12 +82,13 @@ def __call__(self, combine_fn, input, dim): def associative_scan( combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree], - input: pytree.PyTree, + xs: pytree.PyTree, dim: int, reverse: bool = False, + combine_mode: str = "pointwise", ) -> torch.Tensor: r""" - Performs an inclusive scan with an associative pointwise combine function. + Performs an inclusive scan with an associative combine function. .. warning:: `torch.associative_scan` is a prototype feature in PyTorch. It currently @@ -69,11 +102,17 @@ def associative_scan( Args: combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, or if input is a pytree ``(pytree, pytree) -> pytree``. - This function must be pure, pointwise, and satisfy the associative property. - input (torch.Tensor): The input tensor, or nested pytree of tensors. + This function must be pure, i.e., no lifted arguments are supported at the moment, + satisfy the associative property and have no side-effects. + xs (torch.Tensor): The input tensor, or nested pytree of tensors. All inputs are expected to have the same shape. dim (int): the dimension to scan over - reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension. + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``, default ``pointwise``. + If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations + and ``xs`` must be CUDA tensors. + In all other cases ``combine_mode=generic`` should be used. + Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``. Example:: @@ -84,36 +123,65 @@ def add(x: torch.Tensor, y: torch.Tensor): cumsum = associative_scan(add, x, dim) """ - assert callable(combine_fn), "combine_fn must be a callable, but got {combine_fn}" - assert isinstance(dim, int), "dim must be an int, but got {type(dim)}" + if not callable(combine_fn): + raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}") + if not isinstance(dim, int): + raise RuntimeError("Dim must be an int, but got " + str(type(dim))) + if combine_mode not in ["pointwise", "generic"]: + raise RuntimeError( + "Combine_mode must either 'pointwise' or 'generic', but got {combine_mode}" + ) if not torch._dynamo.is_compiling(): with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): return torch.compile(associative_scan, fullgraph=True)( - combine_fn, input, dim, reverse=reverse + combine_fn, xs, dim, reverse=reverse, combine_mode=combine_mode ) - leaves, spec = pytree.tree_flatten(input) + leaves, spec = pytree.tree_flatten(xs) + + if combine_mode == "pointwise" and not all(l.device.type == "cuda" for l in leaves): + raise ValueError( + "For combine_mode='pointwise', all input tensors need to be on CUDA" + ) + + if len(leaves) == 0: + raise RuntimeError("Expected at least 1 xs leaf") + if any(not isinstance(x, torch.Tensor) for x in leaves): + raise RuntimeError("xs leaves must be a Tensor") if reverse: leaves = [torch.flip(elem, [dim]) for elem in leaves] - assert len(leaves) >= 1, "expected at least 1 input leaf" - assert all( - isinstance(x, torch.Tensor) for x in leaves - ), "input leaves must be a Tensor" shape = leaves[0].shape ndim = len(shape) dim = utils.canonicalize_dim(ndim, dim) for x in leaves[1:]: - assert x.shape == shape, "All input tensors must have the same shape" + assert x.shape == shape, "All xs tensors must have the same shape" + + out = combine_fn( + pytree.tree_unflatten(leaves, spec), + pytree.tree_unflatten(leaves, spec), + ) + out_leaves, tree_out = pytree.tree_flatten(out) + if len(leaves) != len(out_leaves): + raise RuntimeError( + "The number of leaves of the pytree of the output of the operator needs to match the length of the pytree of the input" + ) + if any(x.shape != shape for x in out_leaves): + raise RuntimeError( + "The pytree of the output of the operator needs to match the xs pytree" + ) combine_fn = functools.partial( wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves) ) - result_flat = associative_scan_op(combine_fn, leaves, dim) + if combine_mode == "generic": + result_flat = generic_associative_scan(combine_fn, leaves, dim) + else: + result_flat = associative_scan_op(combine_fn, leaves, dim) if reverse: result_flat = [torch.flip(elem, [dim]) for elem in result_flat] @@ -121,15 +189,116 @@ def add(x: torch.Tensor, y: torch.Tensor): return pytree.tree_unflatten(result_flat, spec) +def generic_associative_scan(operator, elems_flat, dim=0): + r""" + This function performs the associative_scan operation. + The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently + applying the ``operator`` on all pairs in parallel along ``dim``. + The results of the recursive calls are later combined. + + Args: + operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``, + or if input is a pytree ``(pytree, pytree) -> pytree``. + This function must be pure, pointwise, and satisfy the associative property. + elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of + ``xs`` provided to ``associative_scan``. + All inputs are expected to have the same shape. + dim (int): the dimension to scan over + + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + return x + y + + elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0]) + + First iteration of _scan -> + # odd_elems -> apply operator on all neighbours + # odd_elems = operator([torch.tensor([0.0, 2.0])], + # [torch.tensor([1.0, 3.0])]) + odd_elems = torch.tensor([1.0, 5.0]) + Second iteration of _scan -> + # odd_elems = operator([torch.tensor([1.0])], + # [torch.tensor([5.0])]) + odd_elems = torch.tensor([6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [1.0] + # Merges odd_elems and even_elems + res = torch.tensor([1.0, 6.0]) + # even_elems -> apply operator on all odd_elems and + # every second element of ``elems``, starting from the second element. + # even_elems is expanded with the first element of ``elems`` + even_elems = [0.0, 3.0] + # Merges odd_elems and even_elems + res = torch.tensor([0.0, 1.0, 3.0, 6.0]) + + """ + + def _scan(elems): + """Perform the actual recursive scan on ``elems``.""" + num_elems = elems[0].shape[dim] + + if num_elems < 2: + return elems + + reduced_elems = operator( + *[aten.slice(elem, dim, 0, -1, 2) for elem in elems], + *[aten.slice(elem, dim, 1, None, 2) for elem in elems], + ) + + # Recursively compute scan for partially reduced tensors. + odd_elems = _scan(reduced_elems) + + if num_elems % 2 == 0: + even_elems = operator( + *[aten.slice(e, dim, 0, -1) for e in odd_elems], + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + else: + even_elems = operator( + *odd_elems, + *[aten.slice(e, dim, 2, None, 2) for e in elems], + ) + + # The first element of a scan is the same as the first element + # of the original `elems`. + even_elems = [ + torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim) + if result.shape.numel() > 0 and elem.shape[dim] > 0 + else result + if result.shape.numel() > 0 + else aten.slice( + elem, dim, 0, 1 + ) # Jax allows/ignores concat with 0-dim, Pytorch does not + for (elem, result) in zip(elems, even_elems) + ] + + return list( + safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems) + ) + + scans = _scan(elems_flat) + + return scans + + def trace_associative_scan( - proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int + proxy_mode, func_overload, combine_fn: Callable, xs: List[torch.Tensor], dim: int ): with disable_proxy_modes_tracing(): - sample_inputs = [ - torch.full((), False, dtype=x.dtype, device=x.device) - for x in itertools.chain(input, input) + sample_xs = [ + torch.empty_like( + x, + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + for x in itertools.chain(xs, xs) ] - combine_graph = reenter_make_fx(combine_fn)(*sample_inputs) + combine_graph = reenter_make_fx(combine_fn)(*sample_xs) outputs = None for node in combine_graph.graph.nodes: @@ -138,42 +307,41 @@ def trace_associative_scan( assert len(node.args) == 1 outputs = node.args[0] + if not all(is_pointwise_use(use) or use.op == "output" for use in node.users): + raise ValueError( + "For combine_mode='pointwise', the combine_fn needs to be pointwise" + ) + assert outputs is not None assert len(outputs) == len( - input - ), f"expected combine_fn to return {len(input)} results but got {len(outputs)}" + xs + ), f"expected combine_fn to return {len(xs)} results but got {len(outputs)}" - for i, o in zip(input, outputs): + for i, o in zip(xs, outputs): o_meta = o.meta["tensor_meta"] assert o_meta.dtype == i.dtype, ( f"combine_fn output type mismatch, expected {i.dtype} " + f"but got {o_meta.dtype}" ) - assert ( - o_meta.shape == () - ), f"combine_fn must return a scalar tensor but got shape {o_meta.shape}" - assert ( - o_meta.shape == () - ), f"combine_fn must return a scalar tensor but got shape {o_meta.shape}" _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) - args = (combine_graph, input, dim) + args = (combine_graph, xs, dim) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) out_proxy = proxy_mode.tracer.create_proxy( "call_function", func_overload, proxy_args, {}, name="associative_scan" ) with disable_proxy_modes_tracing(): - out = [aten.clone(x) for x in input] + out = [aten.clone(x) for x in xs] return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) @associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) -def associative_scan_op_dense(combine_fn, input, dim): +def associative_scan_op_dense(combine_fn, xs, dim): raise NotImplementedError("associative_scan is not implemented for eager") @@ -183,45 +351,22 @@ def associative_scan_op_dense(combine_fn, input, dim): @associative_scan_op.py_impl(ProxyTorchDispatchMode) -def associative_scan_proxy_mode(mode, combine_fn, input, dim): - return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim) +def associative_scan_proxy_mode(mode, combine_fn, xs, dim): + return trace_associative_scan(mode, associative_scan_op, combine_fn, xs, dim) @associative_scan_op.py_impl(FakeTensorMode) -def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim): +def assoiciative_scan_fake_tensor_mode(mode, combine_fn, xs, dim): with mode: - return [x.clone() for x in input] + return [x.clone() for x in xs] @associative_scan_op.py_functionalize_impl -def associative_scan_functionalize(ctx, combine_fn, input, dim): - unwrapped_input = ctx.unwrap_tensors(input) +def associative_scan_functionalize(ctx, combine_fn, xs, dim): + unwrapped_xs = ctx.unwrap_tensors(xs) with ctx.redispatch_to_next() as m: - functional_combine_fn = ctx.functionalize(combine_fn) - ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim) + functional_combine_fn = ctx.functionalize( + _maybe_run_with_interpreter(combine_fn) + ) + ret = associative_scan_op(functional_combine_fn, unwrapped_xs, dim) return ctx.wrap_tensors(ret) - - -@associative_scan_op.py_impl(torch._C._functorch.TransformType.Vmap) -def associative_scan_batch_rule(interpreter, input, dim, combine_fn): - input_ = [get_unwrapped(x) for x in input] - input_bdims = [maybe_get_bdim(x) for x in input] - - batch_size = None - for inp, bdim in zip(input, input_bdims): - if bdim is not None: - batch_size = get_unwrapped(inp).shape[bdim] - - assert batch_size - input_unwrapped = [] - for x, bdim in zip(input, input_bdims): - unwrap = get_unwrapped(x) - if dim is None: - unwrap = unwrap.unsqueeze(0).expand(batch_size, *x.shape) - else: - unwrap = unwrap.movedim(bdim, 0) - input_unwrapped.append(unwrap) - - res = associative_scan_op(combine_fn, input_unwrapped, dim + 1) - lvl = interpreter.level() - return [_add_batch_dim(x, 0, lvl) for x in res] diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index bd84ec7f2eb600..232981f1f0192e 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils._pytree as pytree @@ -17,6 +18,149 @@ ) +def get_base(tensor): + if torch.is_inference_mode_enabled(): + return tensor._inference_mode_base + else: + return tensor._base + + +@dataclass +class ViewInfo: + base_index: int + size: Optional[Sequence[Union[int, torch.SymInt]]] = None + stride: Optional[Sequence[Union[int, torch.SymInt]]] = None + storage_offset: Optional[int] = None + # When is_view is false, the tensor is the base, and + # size, stride and storage_offset are all None. + is_view: bool = True + + def regenerate_view(self, bases_list: List[Tensor]): + if not self.is_view: + return bases_list[self.base_index] + + assert self.stride is not None + assert self.size is not None + assert self.storage_offset is not None + + return torch.as_strided( + bases_list[self.base_index], + self.size, + self.stride, + self.storage_offset, + ) + + +def write_view_information_to_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + arg_to_base_index: Dict[str, Any], +): + """ + This function writes the view information into kwargs. It reads mutable_args from kwargs. + and uses arg_to_base_index and tensor information to write ViewInfo into kwargs. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs: the original custom operator args. + arg_to_base_index: maps mutable_arg_name to int | [int] that refers to the base tensor that + corresponds to the input tensor + """ + + def write_single_view(prefix: str, tensor: Tensor, base_index: int): + assert f"{prefix}_base_index" not in kwargs + assert f"{prefix}_size" not in kwargs + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + if tensor is None: + kwargs[f"{prefix}_base_index"] = None + elif get_base(tensor) is None: + # if the tensor is the base (not view), for simplicity we do not serialize view meta. + kwargs[f"{prefix}_base_index"] = base_index + else: + kwargs[f"{prefix}_base_index"] = base_index + kwargs[f"{prefix}_size"] = tensor.size() + kwargs[f"{prefix}_stride"] = tensor.stride() + kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + arg = kwargs[arg_name] + if isinstance(arg_type, torch.ListType): + if arg is None: + kwargs[f"_{arg_name}_length"] = None + + kwargs[f"_{arg_name}_length"] = len(arg) + for i, elem in enumerate(arg): + write_single_view( + f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i] + ) + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + write_single_view( + f"_{arg_name}", + kwargs[arg_name], + arg_to_base_index.get(arg_name, None), + ) + else: + raise RuntimeError(f"Unsupported type {arg_type}") + + +# Returns a dict of arg_name -> ViewInfo | [ViewInfo] +def read_view_information_from_args( + mutable_arg_names: List[str], + mutable_arg_types: List[torch.Type], + kwargs: Dict[str, Any], + all_bases: List[Tensor], +): + """ + This reads the view information added by `write_view_information_to_args` from kwargs, pop them, + and returns a dict arg_name -> ViewInfo | [ViewInfo](if the input is list). that maps each mutable arg + to its view information. + mutable_arg_names: mutable custom operator arg names. + mutable_arg_types: mutable custom operator arg types. + kwargs : args of auto_functionalize(custom_op, kwargs) + """ + + def get_arg(name): + return kwargs.pop(name) + + def read_single_view(prefix): + base_index = get_arg(f"{prefix}_base_index") + if base_index is None: + return None + elif f"{prefix}_size" not in kwargs: + assert f"{prefix}_stride" not in kwargs + assert f"{prefix}_storage_offset" not in kwargs + + # This means that the argument is the base tensor + return ViewInfo(base_index, all_bases[base_index], is_view=False) + + else: + size = get_arg(f"{prefix}_size") + stride = get_arg(f"{prefix}_stride") + storage_offset = get_arg(f"{prefix}_storage_offset") + return ViewInfo(base_index, size, stride, storage_offset, is_view=True) + + args_view_info: Dict[str, Any] = {} + for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): + if isinstance(arg_type, torch.ListType): + length = get_arg(f"_{arg_name}_length") + if length is None: + # The whole list is None. + args_view_info[arg_name] = None + else: + args_view_info[arg_name] = [ + read_single_view(f"_{arg_name}_{i}") for i in range(length) + ] + + elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)): + args_view_info[arg_name] = read_single_view(f"_{arg_name}") + else: + raise RuntimeError(f"Unsupported type {arg_type}") + return args_view_info + + # NOTE: [auto-functionalizing custom ops] # Users may wish to torch.compile custom ops that mutate their inputs. # torch.compile will automatically support this op without anyone needing @@ -34,6 +178,9 @@ # This HOP effectively runs the functional version of the op when # called: it clones inputs that will be mutated, runs the op, and # then returns (output, Tensors with the new values) +# +# auto_functionalize_v2 is an improved version of auto_functionalize that better handle +# re-inplacing views. class AutoFunctionalized(HigherOrderOperator): @@ -71,6 +218,38 @@ def __call__( auto_functionalized = AutoFunctionalized() auto_functionalized.__module__ = "torch.ops.higher_order" +auto_functionalized.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) + + +class AutoFunctionalizedV2(HigherOrderOperator): + """auto_functionalized_v2(_mutable_op, **kwargs) + + This HOP runs a "functional" version of _mutable_op. + Unlike AutoFunctionalized, this version is improved to better handle + view tensors. This version is only used in non export mode. + """ + + def __init__(self) -> None: + super().__init__("auto_functionalized_v2") + + def __call__( + self, + /, + _mutable_op: OpOverload, + **kwargs: Any, + ) -> Tuple[Any, Tuple[Tensor, ...]]: + assert can_auto_functionalize(_mutable_op) + assert isinstance(kwargs, dict) + return super().__call__(_mutable_op, **kwargs) + + +auto_functionalized_v2 = AutoFunctionalizedV2() +auto_functionalized_v2.__module__ = "torch.ops.higher_order" + +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU) +auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA) + def can_auto_functionalize(op: OperatorBase) -> bool: if not isinstance(op, OpOverload): @@ -120,85 +299,23 @@ def can_auto_functionalize(op: OperatorBase) -> bool: return True -@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) -def auto_functionalized_dense( - _mutable_op: OpOverload, - _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - new_kwargs = dict(**kwargs) - result = [] - - _mutable_args_names = get_mutable_arg_names(_mutable_op) - for name in _mutable_args_names: - if ( - _only_clone_these_tensors is not None - and name not in _only_clone_these_tensors - ): - new_kwargs[name] = kwargs[name] - else: - new_kwargs[name] = ( - [clone_preserve_strides(x) for x in kwargs[name]] - if kwargs[name] is not None and isinstance(kwargs[name], list) - else clone_preserve_strides(kwargs[name]) - if kwargs[name] is not None - else None - ) - result.append(new_kwargs[name]) - out = _mutable_op(**new_kwargs) - - if isinstance(out, tuple): - return (*out, *result) # type: ignore[return-value] - else: - return (out, *result) # type: ignore[return-value] - - -@auto_functionalized.py_impl(FakeTensorMode) -def auto_functionalized_fake( - mode, - _mutable_op: OpOverload, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - with mode: - result = auto_functionalized_dense(_mutable_op, **kwargs) - return result - - -@auto_functionalized.py_impl(ProxyTorchDispatchMode) -def auto_functionalized_proxy( - mode, - _mutable_op: OpOverload, - **kwargs: Any, -) -> Tuple[Any, Tuple[Tensor, ...]]: - with disable_proxy_modes_tracing(): - out = auto_functionalized(_mutable_op, **kwargs) - - proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) - out_proxy = mode.tracer.create_proxy( - "call_function", - auto_functionalized, - (_mutable_op,), - proxy_kwargs, - ) - result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) - return result - - -auto_functionalized.fallthrough(DispatchKey.AutogradCPU) -auto_functionalized.fallthrough(DispatchKey.AutogradCUDA) - - -def get_mutable_arg_names(op: OpOverload) -> List[str]: +def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]: """ Returns the list of argument names that get mutated according to the - schema. + schema and their types. """ mutable_args_names = [ arg.name for arg in op._schema.arguments if arg.alias_info is not None and arg.alias_info.is_write ] - return mutable_args_names + + mutable_args_types = [ + arg.type + for arg in op._schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + return mutable_args_names, mutable_args_types def do_auto_functionalize( @@ -245,7 +362,7 @@ def do_auto_functionalize( ) # List of the name of args that get mutated (according to the schema) - mutable_args_names = get_mutable_arg_names(op) + mutable_args_names, _ = get_mutable_args(op) unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[ : -len(mutable_args_names) @@ -290,9 +407,307 @@ def sync_update(o, orig_arg): return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] +def do_auto_functionalize_v2( + op: OpOverload, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], +) -> Any: + from torch._subclasses.functional_tensor import PythonFunctionalizeAPI + + ctx = PythonFunctionalizeAPI() + + # All of the (args, kwargs), but all as kwargs. The names for the + # args come from the schema. This makes it easier for us to work with them. + normalized_kwargs = {} + + schema = op._schema + for idx, arg in enumerate(schema.arguments): + # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema + if arg.name in kwargs: + normalized_kwargs[arg.name] = kwargs[arg.name] + elif idx < len(args): + # if its out of bounds we don't need to do anything + # as it means the the optional arg was passed with its default + # value + normalized_kwargs[arg.name] = args[idx] + else: + normalized_kwargs[arg.name] = arg.default_value + + # List of the name of args that get mutated (according to the schema) + mutable_args_names, mutable_args_types = get_mutable_args(op) + + # A list of all bases of mutable args without duplication + all_bases = [] + all_bases_addresses: list[int] = [] + + # Map arg_name to the index of its base in all_bases. + arg_to_base_index: Dict[str, Any] = {} + + def update_dict(tensor, arg_name, index=None): + base = tensor if get_base(tensor) is None else get_base(tensor) + + def set_result(base_index): + if index is None: + arg_to_base_index[arg_name] = base_index + else: + arg_to_base_index[arg_name][index] = base_index + + if not all_bases_addresses.__contains__(base._cdata): + all_bases_addresses.append(base._cdata) + all_bases.append(base) + set_result(len(all_bases) - 1) + else: + set_result(all_bases_addresses.index(base._cdata)) + + for arg_name in mutable_args_names: + arg = normalized_kwargs[arg_name] + if arg is None: + continue + + if isinstance(arg, list): + arg_to_base_index[arg_name] = {} + for i, tensor in enumerate(arg): + if tensor is None: + arg_to_base_index[arg_name].append(None) + continue + + update_dict(tensor, arg_name, i) + + else: + update_dict(arg, arg_name) + + # add view_meta for each args into unwrapped_kwargs. + write_view_information_to_args( + mutable_args_names, + mutable_args_types, + normalized_kwargs, + arg_to_base_index, + ) + + # remove mutated args from the kwargs (its a function of _all_bases now) + for arg_name in mutable_args_names: + del normalized_kwargs[arg_name] # type: ignore[arg-type] + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs: + warnings.warn( + "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. " + "Please consider using a different name for this argument to avoid potential issues." + ) + all_basis_unwrapped = ctx.unwrap_tensors(all_bases) + + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized_v2( + op, **dict(unwrapped_kwargs, _all_bases=all_basis_unwrapped) # type: ignore[arg-type] + ) + + unwrapped_actual_out: Union[Any, Tuple[Any]] = ( + unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)] + ) + + unwrapped_mutable_out = ( + [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :] + ) + + if len(op._schema.returns) == 0: + assert unwrapped_actual_out[0] is None + unwrapped_actual_out = None + elif len(op._schema.returns) == 1: + assert len(unwrapped_actual_out) == 1 + unwrapped_actual_out = unwrapped_actual_out[0] + else: + assert len(unwrapped_actual_out) == len(op._schema.returns) + + for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out): + # Can be None if input was `Tensor(a!)?` + if unwrapped_out is None: + continue + + # We only handle Tensor or List[Tensor] here for now. + def sync_update(o, orig_arg): + ctx.replace(orig_arg, o) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + if isinstance(unwrapped_out, torch.Tensor): + sync_update(unwrapped_out, orig_arg) + elif isinstance(unwrapped_out, list) and all( + isinstance(o, torch.Tensor) for o in unwrapped_out + ): + assert len(orig_arg) == len(unwrapped_out) + for orig_a, o in zip(orig_arg, unwrapped_out): + sync_update(o, orig_a) + else: + raise RuntimeError( + f"unsupported type for auto-functionalization: {unwrapped_out}" + ) + + return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type] + + +# auto_functionalize functions +@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_dense( + _mutable_op: OpOverload, + _only_clone_these_tensors: Optional[Tuple[str, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + new_kwargs = dict(**kwargs) + result = [] + + _mutable_args_names, _ = get_mutable_args(_mutable_op) + for name in _mutable_args_names: + if ( + _only_clone_these_tensors is not None + and name not in _only_clone_these_tensors + ): + new_kwargs[name] = kwargs[name] + else: + new_kwargs[name] = ( + [clone_preserve_strides(x) for x in kwargs[name]] + if kwargs[name] is not None and isinstance(kwargs[name], list) + else clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) + result.append(new_kwargs[name]) + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *result) # type: ignore[return-value] + else: + return (out, *result) # type: ignore[return-value] + + +@auto_functionalized.py_impl(FakeTensorMode) +def auto_functionalized_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + @auto_functionalized.py_functionalize_impl def auto_functionalized_func(ctx, _mutable_op, **kwargs): unwrapped_kwargs = ctx.unwrap_tensors(kwargs) with ctx.redispatch_to_next(): result = auto_functionalized(_mutable_op, **unwrapped_kwargs) return ctx.wrap_tensors(result) + + +# auto_functionalized_v2 functions +@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd) +def auto_functionalized_v2_dense( + _mutable_op: OpOverload, + _only_clone_these_bases: Optional[Tuple[int, ...]] = None, + **kwargs: Any, +) -> Tuple[Any, Tuple[Tensor, ...]]: + all_bases: List[Tensor] = kwargs.pop("_all_bases", []) + mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op) + args_view_info = read_view_information_from_args( + mutable_args_names, mutable_args_types, kwargs, all_bases + ) + + if _only_clone_these_bases is None: + _only_clone_these_bases = tuple(range(len(all_bases))) + + def maybe_copy(i, t): + if t is None: + return None + if i in _only_clone_these_bases: + return clone_preserve_strides(t) + else: + return t + + all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)] + + # create new args + new_kwargs = dict(**kwargs) + + # re-generate all inputs from all_bases_new using args_view_info and add them to new_kwargs. + for arg_name in mutable_args_names: + if args_view_info[arg_name] is None: + new_kwargs[arg_name] = None + elif isinstance(args_view_info[arg_name], list): + new_kwargs[arg_name] = [] + for i, elem in enumerate(args_view_info[arg_name]): + if elem is None: + new_kwargs[arg_name].append(None) + else: + view_info = args_view_info[arg_name][i] + new_kwargs[arg_name].append( + view_info.regenerate_view(all_bases_new) + ) + else: + new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view( + all_bases_new + ) + + out = _mutable_op(**new_kwargs) + + if isinstance(out, tuple): + return (*out, *all_bases_new) # type: ignore[return-value] + else: + return (out, *all_bases_new) # type: ignore[return-value] + + +@auto_functionalized_v2.py_impl(FakeTensorMode) +def auto_functionalized_v2_fake( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with mode: + result = auto_functionalized_v2_dense(_mutable_op, **kwargs) + return result + + +@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode) +def auto_functionalized_v2_proxy( + mode, + _mutable_op: OpOverload, + **kwargs: Dict[str, Any], +) -> Tuple[Any, Tuple[Tensor, ...]]: + with disable_proxy_modes_tracing(): + out = auto_functionalized_v2(_mutable_op, **kwargs) + + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", + auto_functionalized_v2, + (_mutable_op,), + proxy_kwargs, + ) + result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + return result + + +@auto_functionalized_v2.py_functionalize_impl +def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs): + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + with ctx.redispatch_to_next(): + result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs) + return ctx.wrap_tensors(result) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 49a1a4cd45fa39..0467e2899adc28 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -18,6 +18,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, _set_compilation_env, reenter_make_fx, unique_graph_id, @@ -27,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, disable_proxy_modes_tracing, ProxyTorchDispatchMode, @@ -128,6 +130,10 @@ def false_fn(x: torch.Tensor): if torch.compiler.is_dynamo_compiling(): return cond_op(pred, true_fn, false_fn, operands) + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + if isinstance(pred, (bool, int, float)): log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." @@ -168,12 +174,15 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - with _set_compilation_env(): - with torch._dynamo.utils.disable_cache_limit(): - with _temp_remove_pre_dispatch_torch_function_mode(): - return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)( - pred, true_fn, false_fn, operands - ) + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(), _temp_remove_pre_dispatch_torch_function_mode(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( + pred, true_fn, false_fn, operands + ) def create_fw_bw_graph_branches(true_fn, false_fn, *operands): @@ -449,8 +458,8 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs): unwrapped_inputs = ctx.unwrap_tensors(inputs) unwrapped_pred = ctx.unwrap_tensors(pred) with ctx.redispatch_to_next() as m: - functional_true = ctx.functionalize(true_fn) - functional_false = ctx.functionalize(false_fn) + functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn)) + functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn)) pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch for branch in [functional_true, functional_false]: if _has_potential_branch_input_mutation( diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py index e6ba7f65cf182f..651b3d6e2caa85 100644 --- a/torch/_higher_order_ops/effects.py +++ b/torch/_higher_order_ops/effects.py @@ -2,6 +2,7 @@ # mypy: allow-untyped-defs from enum import Enum from typing import Any, Dict, Optional, Tuple, Union +from weakref import WeakKeyDictionary import torch import torch.utils._pytree as pytree @@ -23,11 +24,12 @@ class _EffectType(Enum): OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload] -# TODO(ivankobzarev): Make SIDE_EFFECTS dictionary WeakKeyDictionary as operator can go out of scope -SIDE_EFFECTS: Dict[OpType, _EffectType] = { - torch.ops.aten._print.default: _EffectType.ORDERED, - call_torchbind: _EffectType.ORDERED, -} +SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary( + { + torch.ops.aten._print.default: _EffectType.ORDERED, + call_torchbind: _EffectType.ORDERED, + } +) def _register_effectful_op(op: OpType, effect: _EffectType): @@ -125,8 +127,7 @@ def get_effect_key(op, args, kwargs) -> Optional[_EffectType]: def new_token_tensor() -> torch.Tensor: - # Use dtype bool to not affect Inductor dtype promotions - return torch.tensor([], dtype=torch.bool) + return torch.tensor([]) @with_effects.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py index 4ef87f36080462..a6ee5205ff4ebb 100644 --- a/torch/_higher_order_ops/executorch_call_delegate.py +++ b/torch/_higher_order_ops/executorch_call_delegate.py @@ -29,6 +29,9 @@ class ExecutorchCallDelegate(HigherOrderOperator): def __init__(self): super().__init__("executorch_call_delegate") + def __call__(self, lowered_module, *args): + return super().__call__(lowered_module, *args) + executorch_call_delegate = ExecutorchCallDelegate() executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index d66f8f4ba54d84..ed99ae7beca12a 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import math -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Sequence, Tuple, Union import torch import torch.utils._pytree as pytree @@ -23,6 +23,53 @@ from torch.overrides import TorchFunctionMode +# Duplicate of _inductor/kernel/flex_attention.py to avoid circular import +def _construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch.Tensor: + """ + Create a new tensor with the same data and shape as the input, + but with strides permuted based on the input tensor's stride order. + + Args: + out (torch.Tensor): The output tensor of attention. + query_strides (List[int]): The stride order of the input query tensor + + Returns: + torch.Tensor: A new tensor with same shape and data as the input, + but with strides permuted based on the query tensor's stride order. + """ + from torch._inductor.ir import get_stride_order, stride_order2fill_order + + stride_order = get_stride_order(query_strides) + fill_order = stride_order2fill_order(stride_order) + assert out.storage_offset() == 0, "Only support storage_offset == 0" + out_strides = _construct_strides(out.shape, fill_order) + new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) + new_out.copy_(out) + return new_out + + class TransformGetItemToIndex(TorchFunctionMode): # This is needed since we want to support calling # A[q_idx], where q_idx is a scalar tensor in score_mod. @@ -30,7 +77,7 @@ class TransformGetItemToIndex(TorchFunctionMode): # scalar and create a view. We do not want that behavior in this case, so we # use this torchfunctionmode to override that behavior for score_mod # wherever we're running it. - def __torch_function__(self, func, types, args, kwargs=None): + def __torch_function__(self, func, types, args=(), kwargs=None): if func == torch.Tensor.__getitem__: index_args = pytree.tree_leaves(args[1]) if all(isinstance(x, torch.Tensor) for x in index_args): @@ -192,6 +239,13 @@ def math_attention( value = torch.repeat_interleave(value, G, dim=1) key = torch.repeat_interleave(key, G, dim=1) + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + _, post_mod_scores = _math_attention_inner( query, key, @@ -207,7 +261,7 @@ def math_attention( # Set fully masked rows' sumexp to 0.0 logsumexp = post_mod_scores.logsumexp(dim=-1) masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1) - logsumexp = torch.where(masked_rows, 0.0, logsumexp) + logsumexp = torch.where(masked_rows, -float("inf"), logsumexp) post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1) @@ -237,7 +291,7 @@ def sdpa_dense( score_mod_other_buffers, mask_mod_other_buffers, ) - out = out.contiguous() + out = _permute_strides(out, query.stride()) return out, lse @@ -425,7 +479,9 @@ def flex_attention_fake_tensor_mode( batch_size, num_heads, seq_len_q, dtype=torch.float32 ) out_shape = (batch_size, num_heads, seq_len_q, v_head_dim) - return query.new_empty(out_shape), logsumexp + out = query.new_empty(out_shape) + out = _permute_strides(out, query.stride()) + return out, logsumexp # ---------------------------- Autograd Implementation ---------------------------- @@ -686,6 +742,18 @@ def sdpa_dense_backward( score_mod_other_buffers: Tuple, mask_mod_other_buffers: Tuple, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get outputs before calling repeat interleave + actual_grad_query = torch.empty_like(query) + actual_grad_key = torch.empty_like(key) + actual_grad_value = torch.empty_like(value) + + Bq, Bkv = query.size(0), key.size(0) + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + + key = key.expand((Bq, *key.size()[1:])) + value = value.expand((Bq, *value.size()[1:])) + G = query.size(1) // key.size(1) key = torch.repeat_interleave(key, G, dim=1) value = torch.repeat_interleave(value, G, dim=1) @@ -705,7 +773,9 @@ def sdpa_dense_backward( score_mod_other_buffers, mask_mod_other_buffers, ) + masked_out_rows = logsumexp == -float("inf") softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) + softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores) grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out @@ -765,7 +835,20 @@ def sdpa_dense_backward( grad_key = torch.sum(grad_key, 2, keepdim=False) grad_value = torch.sum(grad_value, 2, keepdim=False) - return grad_query.contiguous(), grad_key.contiguous(), grad_value.contiguous() + if Bq != Bkv: + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + + # Reduce DK, DV along broadcasted batches. + grad_key = torch.sum(grad_key, 0, keepdim=True) + grad_value = torch.sum(grad_value, 0, keepdim=True) + + actual_grad_query.copy_(grad_query) + actual_grad_key.copy_(grad_key) + actual_grad_value.copy_(grad_value) + + return actual_grad_query, actual_grad_key, actual_grad_value def trace_flex_attention_backward( diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py new file mode 100644 index 00000000000000..c211d405614636 --- /dev/null +++ b/torch/_higher_order_ops/hints_wrap.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +# used for wrapping a function/op with context hints +class HintsWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("hints_wrapper") + + def __call__(self, body_fn, args, kwargs, hints): + r""" + Call implementation of hints_wrapper + + Args: + body_fn (Callable): A callable function that is within the scope + that is being traced. + + args (Tuple of torch.Tensor/int/float/bool): A tuple of inputs to + body_fn. + + kwargs (dict): Keyword argument to the body_fn. + + hints (dict): A dict of context hints which could be passed to + backend compiler. + """ + if not isinstance(args, tuple): + raise RuntimeError(f"args must be a tuple, got {type(args)}") + + if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): + raise RuntimeError( + "args must be a tuple of tensors, ints, floats, or bools, got " + f"{args}" + ) + + if not isinstance(kwargs, dict): + raise RuntimeError(f"kwargs must be a dict, got {type(kwargs)}") + + if len(kwargs) > 0: + raise RuntimeError( + f"kwargs except for hints are not supported, got {kwargs}" + ) + + if not isinstance(hints, dict): + raise RuntimeError(f"hints must be a dict, got {type(hints)}") + + for k, v in hints.items(): + if not isinstance(k, str): + raise RuntimeError(f"hints key must be a str, got {k}.") + + if not isinstance(v, (int, float, bool, str)): + raise RuntimeError( + "hints must be a dict containing int, float, bool or str " + f"value, got value {v} for key {k}." + ) + + return super().__call__(body_fn, args, kwargs, hints) + + +hints_wrapper = HintsWrapper() + + +@hints_wrapper.py_impl(DispatchKey.CompositeExplicitAutograd) +def hints_wrapper_dense(body_fn, args, kwargs, hints): + return body_fn(*args, **kwargs) + + +hints_wrapper.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(hints_wrapper, deferred_error=True) +) + + +@hints_wrapper.py_impl(FakeTensorMode) +def hints_wrapper_fake_tensor_mode(mode, body_func, args, kwargs, hints): + flat_args = pytree.tree_leaves(args) + with mode: + return body_func(*flat_args, **kwargs) + + +@hints_wrapper.py_functionalize_impl +def hints_wrapper_functionalize(ctx, body_fn, args, kwargs, hints): + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + unwrapped_hints = ctx.unwrap_tensors(hints) + with ctx.redispatch_to_next(): + functional_body_fn = ctx.functionalize(body_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + if _has_potential_branch_input_mutation( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be modifying the input!" + ) + if _has_potential_branch_input_alias( + functional_body_fn, unwrapped_args, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "body_fn of hints_wrapper might be aliasing the input!" + ) + outputs = hints_wrapper( + functional_body_fn, + unwrapped_args, + unwrapped_kwargs, + unwrapped_hints, + ) + return ctx.wrap_tensors(outputs) + + +def trace_hints_wrapper(proxy_mode, hints_wrapper, body_fn, args, kwargs, hints): + flat_args = tuple(pytree.tree_leaves(args)) + body_graph = reenter_make_fx(body_fn)(*flat_args, **kwargs) + + _, body_graph_name = unique_graph_id(proxy_mode, prefix="hints_wrapper_body_graph") + proxy_mode.tracer.root.register_module(body_graph_name, body_graph) + + new_args: tuple = (body_graph, flat_args, {}) + # merge hints into kwargs + new_kwargs = {} + new_kwargs["hints"] = hints + + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_args) + proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, new_kwargs) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", hints_wrapper, proxy_args, proxy_kwargs, name="hints_wrapper" + ) + + out = body_fn(*flat_args, **kwargs) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@hints_wrapper.py_impl(ProxyTorchDispatchMode) +def inner(proxy_mode, body_fn, args, kwargs, hints): + if proxy_mode.enable_tracing: + return trace_hints_wrapper( + proxy_mode, hints_wrapper, body_fn, args, kwargs, hints + ) + else: + return hints_wrapper(body_fn, args, kwargs, hints) diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index b91a7a3701d848..d57d68d5e473f7 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -7,6 +7,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, reenter_make_fx, UnsupportedAliasMutationException, ) @@ -32,6 +33,9 @@ # TODO: We add this to prevent dymamo from tracing into map_wrapper, # remove the wrapper call when it's ready. class MapWrapper(HigherOrderOperator): + def __init__(self): + super().__init__("map") + def __call__(self, xs, *args): return map_wrapper(xs, *args) @@ -40,8 +44,11 @@ class MapImpl(HigherOrderOperator): def __init__(self): super().__init__("map_impl") + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + -map = MapWrapper("map") +map = MapWrapper() map_impl = MapImpl() @@ -237,7 +244,7 @@ def map_fake_tensor_mode(mode, f, xs, args): def map_functionalize(ctx, f, xs, pos_args): unwrapped_xs = ctx.unwrap_tensors(xs) unwrapped_args = ctx.unwrap_tensors(pos_args) - wrapped_fn = ctx.functionalize(f) + wrapped_fn = ctx.functionalize(_maybe_run_with_interpreter(f)) with ctx.redispatch_to_next(): with disable_proxy_modes_tracing(): diff --git a/torch/_higher_order_ops/run_const_graph.py b/torch/_higher_order_ops/run_const_graph.py index 774f3ebc0e2a6f..1f49ee28394a1c 100644 --- a/torch/_higher_order_ops/run_const_graph.py +++ b/torch/_higher_order_ops/run_const_graph.py @@ -12,6 +12,9 @@ class RunConstGraph(HigherOrderOperator): def __init__(self): super().__init__("run_const_graph") + def __call__(self, *args): + return super().__call__(*args) + run_const_graph = RunConstGraph() diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py new file mode 100644 index 00000000000000..d66cff067f6682 --- /dev/null +++ b/torch/_higher_order_ops/scan.py @@ -0,0 +1,448 @@ +# mypy: allow-untyped-defs +import functools +import itertools +from typing import Callable, List, Tuple + +import torch +import torch._prims_common as utils +import torch._subclasses.functional_tensor +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import ( + _has_potential_branch_input_alias, + _has_potential_branch_input_mutation, + _set_compilation_env, + autograd_not_implemented, + reenter_make_fx, + unique_graph_id, + UnsupportedAliasMutationException, +) +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import _get_current_dispatch_mode + + +aten = torch._ops.ops.aten + + +def wrap_combine_fn_flat( + *args, combine_fn, spec_init, spec_xs, num_init_leaves, num_inp_leaves +): + assert len(args) == (num_init_leaves + num_inp_leaves) + carry = pytree.tree_unflatten(args[:num_init_leaves], spec_init) + xs = pytree.tree_unflatten(args[num_init_leaves:], spec_xs) + carry, combined = combine_fn(carry, xs) + carry_flat = pytree.tree_leaves(carry) + combined_flat = pytree.tree_leaves(combined) + assert num_init_leaves == len(carry_flat) + return (carry_flat, combined_flat) + + +def scan( + combine_fn: Callable[ + [pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree] + ], + init: pytree.PyTree, + xs: pytree.PyTree, + /, + *, + dim: int = 0, + reverse: bool = False, +) -> Tuple[pytree.PyTree, pytree.PyTree]: + r""" + Performs an inclusive scan with a combine function. + + .. warning:: + `torch.scan` is a prototype feature in PyTorch. It currently + does not support autograd and you may run into miscompiles. + Read more about feature classification at: + https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype + + Args: + combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> (Tensor, Tensor)``, + or if xs is a pytree ``(pytree, pytree) -> (pytree, pytree)``. + The first input to ``combine_fn`` is the previous or initial scan carry + and the second input element to ``combine_fn`` is a slice of the input along dim. + The first output element of ``combine_fn`` is the next scan carry + and the second output of ``combine_fn`` represents a slice of the output. + This function must be pure, i.e., no lifted arguments are supported at the moment + and may not have any side effects. + init (torch.Tensor or pytree with tensor leaves): The inital scan carry, a tensor, or nested pytree of tensors. + The ``init`` is expected to have the same pytree structure as the first output element (i.e. carry) + of ``combine_fn``. + xs (torch.Tensor or pytree with tensor leaves): The input tensor, or nested pytree of tensors. + + Kwargs: + dim (int): the dimension to scan over, default 0. + reverse (bool): A boolean stating if the scan should be reversed with respect to ``dim``, default ``False``. + + Returns: + final_carry (torch.Tensor or pytree with tensor leaves), + the final carry of the scan operation with same pytree structure as init. + out (torch.Tensor or pytree with tensor leaves), + each tensor leaf is a stacked output along dim, where each slice is the output of a scan iteration. + + Example:: + + def add(x: torch.Tensor, y: torch.Tensor): + next_carry = y = x + y + return next_carry, y + + i0 = torch.zeros(1) + xs = torch.arange(1, 5) + # returns torch.tensor([10]), torch.tensor([1., 3., 6., 10.]) + last_carry, cumsum = scan(add, init=i0, xs=xs) + + + """ + if not callable(combine_fn): + raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}") + if not isinstance(dim, int): + raise RuntimeError("Dim must be an int, but got " + str(type(dim))) + if not isinstance(reverse, bool): + raise RuntimeError("Reverse must be a bool, but got " + str(type(reverse))) + + # TODO: Support closures/nn_modules in order to be able represent RNNs with scan + # TODO: Support _inductor lowering + # TODO: Support Autograd + # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. + + # Dynamo is expecting a callable with "__code__" attribute. + # We cannot directly pass cond_op to it. So we wrap it in a dummy function. + def _scan_op_wrapper(*args, **kwargs): + return scan(*args, **kwargs) + + if not torch._dynamo.is_compiling(): + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + + with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( + combine_fn, init, xs, dim=dim, reverse=reverse + ) + + leaves_init, spec_init = pytree.tree_flatten(init) + leaves_xs, spec_xs = pytree.tree_flatten(xs) + + if len(leaves_init) == 0: + raise RuntimeError("Init tensors must be provided") + if any(not isinstance(x, torch.Tensor) for x in leaves_init): + raise RuntimeError("All init leaves must be a Tensor") + if any(not isinstance(x, torch.Tensor) for x in leaves_xs): + raise RuntimeError("All xs leaves must be a Tensor") + if any(x.shape[dim] == 0 for x in leaves_xs): + raise RuntimeError("All xs leaves must have a scan dimension > 0") + + if len(leaves_xs) > 0: + shape = leaves_xs[0].shape + ndim = len(shape) + dim = utils.canonicalize_dim(ndim, dim) + + out = combine_fn( + pytree.tree_unflatten(leaves_init, spec_init), + pytree.tree_unflatten( + [aten.slice(elem, dim, 0, 1, 1) for elem in leaves_xs], spec_xs + ), + ) + + # The first output needs to have the same pytree as init + carry_leaves = pytree.tree_leaves(out[0]) + if len(carry_leaves) != len(leaves_init): + raise RuntimeError( + "The number of leaves of the pytree of the new carry produced by the operator\ + needs to match the length of the pytree of the init" + ) + if any( + in_l.shape != out_l.shape for in_l, out_l in zip(leaves_init, carry_leaves) + ): + raise RuntimeError( + "The pytree of the new carry produced by the operator needs to match the pytree of the init" + ) + + # There are no pytree restrictions on the second output of the operator + out_leaves, tree_out = pytree.tree_flatten(out[1]) + + combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=combine_fn, + spec_init=spec_init, + spec_xs=spec_xs, + num_init_leaves=len(leaves_init), + num_inp_leaves=len(leaves_xs), + ) + + result_carry, result_flat = scan_op( + combine_fn, leaves_init, leaves_xs, dim, reverse + ) + + return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten( + result_flat, tree_out + ) + + else: + return pytree.tree_unflatten(leaves_init, spec_init), xs + + +class ScanOp(HigherOrderOperator): + def __init__(self): + super().__init__("scan") + + def __call__(self, combine_fn, init, xs, dim, reverse): + return super().__call__(combine_fn, init, xs, dim, reverse) + + +scan_op = ScanOp() + + +def generic_scan(operator, init, xs, dim=0, reverse=False): + def _scan(init, xs): + """Perform scan on `elems` using `elems_init.""" + carry = init + if len(xs) == 0: + return carry, [] + + num_elems = xs[0].shape[dim] + if reverse: + ind = num_elems - 1 + else: + ind = 0 + + # Compute dummy shapes for the pre-allocation + dummy_carry, dummy_out = operator( + *carry, *[aten.slice(elem, dim, 0, 1, 1) for elem in xs] + ) + output_scanned_dim = dummy_out[0].shape[dim] + + # Pre-alocate + # outs -> Output matrix + # idxs -> Index matrix for scatter_ + outs, outs_idxs = zip( + *[ + [ + torch.zeros( + list(e.size())[:dim] + + [list(e.size())[dim] * num_elems] + + list(e.size())[dim + 1 :], + dtype=e.dtype, + device=e.device, + ), + torch.cat( + [ + id * t + for id, t in zip( + range(output_scanned_dim), + torch.tensor_split( + torch.ones_like(e, dtype=torch.int64), + output_scanned_dim, + dim=dim, + ), + ) + ], + dim, + ), + ] + for i, e in enumerate(dummy_out) + ] + ) + + def store_in_mat(mat, out, d, index, index_modifier): + # Store the intermediate out in the outs matrix + for o, x, idx in zip(mat, out, index): + o.scatter_(d, idx + index_modifier, x) + + def cond(i, n, r): + if (r and i < 0) or (not r and i > (n - 1)): + return False + else: + return True + + def op(i): + if reverse: + return i - 1 + else: + return i + 1 + + while cond(ind, num_elems, reverse): + carry, out = operator( + *carry, + *[aten.slice(elem, dim, ind, ind + 1, 1) for elem in xs], + ) + + # Store the inits in the outs matrix. + store_in_mat(outs, out, dim, outs_idxs, ind * output_scanned_dim) + + ind = op(ind) + + return (carry, list(outs)) + + scans = _scan(init, xs) + return scans + + +def make_expanded_output_shape(dim, scan_length, shapes, use_sh=False): + expanded_shapes = [ + tuple( + (s if use_sh else -1) if i != dim else scan_length for i, s in enumerate(sh) + ) + for sh in shapes + ] + return expanded_shapes + + +def trace_scan( + proxy_mode, + func_overload, + combine_fn: Callable, + init: List[torch.Tensor], + xs: List[torch.Tensor], + dim: int, + reverse: bool, +): + with disable_proxy_modes_tracing(): + sample_inits = [ + torch.empty_like( + x_init, + dtype=x_init.dtype, + device=x_init.device, + requires_grad=x_init.requires_grad, + ) + for x_init in init + ] + sample_xs = [ + torch.empty_like( + aten.slice(x, dim, 0, 1, 1), + dtype=x.dtype, + device=x.device, + requires_grad=x.requires_grad, + ) + for x in xs + ] + combine_graph = reenter_make_fx(combine_fn)(*sample_inits, *sample_xs) + + outputs = None + for node in combine_graph.graph.nodes: + if node.op == "output": + assert outputs is None + assert len(node.args) == 1 + outputs = node.args[0] + + assert outputs is not None + if len(outputs) != 2: + raise RuntimeError( + f"Expected to return 2 outputs: carry, out_matrix, but got:" + f"\n {len(outputs)} elements" + ) + + for ini, carry in zip(init, outputs[0]): + ini_meta = ini + carry_meta = carry.meta["tensor_meta"] + carry_val = carry.meta["val"] + if ( + carry_val.device != ini_meta.device + or carry_meta.dtype != ini_meta.dtype + or carry_meta.shape != ini_meta.shape + ): + raise RuntimeError( + f"Expected metadata of the combine_fn result {carry_meta} to be the same as " + + f"the metadata of init with {ini_meta}" + ) + + _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph") + + proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) + + args = (combine_graph, init, xs, dim, reverse) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="scan" + ) + + with disable_proxy_modes_tracing(): + scan_length = xs[0].shape[dim] + fake_out_shapes = make_expanded_output_shape( + dim, scan_length, [o.meta["val"].size() for o in outputs[1]] + ) + + def expand_tensor(t, sh): + if isinstance(t, torch.Tensor): + return t.expand(*sh) + return t + + expanded_outs = [ + pytree.tree_map(expand_tensor, t.meta["val"], sh) + for t, sh in zip(outputs[1], fake_out_shapes) + ] + out = (init, expanded_outs) + + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) +def scan_op_dense(combine_fn, init, xs, dim, reverse): + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return generic_scan(combine_fn, init, xs, dim, reverse) + + +scan_op.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(scan_op, deferred_error=True) +) + + +@scan_op.py_impl(ProxyTorchDispatchMode) +def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse): + return trace_scan(mode, scan_op, combine_fn, init, xs, dim, reverse) + + +@scan_op.py_impl(FakeTensorMode) +def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse): + with mode: + dim_len = xs[0].shape[dim] + carry, outputs = combine_fn( + *init, *[aten.slice(inp, dim, 0, 1, 1) for inp in xs] + ) + fake_out_shapes = [ + tuple(-1 if i != dim else dim_len for i, sh in enumerate(o.size())) + for o in outputs + ] + out = ( + carry, + tuple(t.expand(*sh).clone() for t, sh in zip(outputs, fake_out_shapes)), + ) + return out + + +@scan_op.py_functionalize_impl +def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse): + unwrapped_xs = ctx.unwrap_tensors(xs) + unwrapped_init = ctx.unwrap_tensors(init) + with ctx.redispatch_to_next() as m: + functional_combine_fn = ctx.functionalize(combine_fn) + pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch + sample_xs = list(itertools.chain(unwrapped_init, unwrapped_init)) + if _has_potential_branch_input_mutation( + functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "Combine_fn might be modifying the input!" + ) + if _has_potential_branch_input_alias( + functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + ): + raise UnsupportedAliasMutationException( + "Combine_fn might be aliasing the input!" + ) + ret = scan_op(functional_combine_fn, unwrapped_init, unwrapped_xs, dim, reverse) + return ctx.wrap_tensors(ret) diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index c58eea5750681d..7324e20dcd4cd7 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -32,6 +32,9 @@ class StrictMode(HigherOrderOperator): def __init__(self): super().__init__("strict_mode") + def __call__(self, callable, operands): + return super().__call__(callable, operands) + strict_mode_op = StrictMode() diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py index ef2f228bc4b3fe..b35b8d5b296d11 100644 --- a/torch/_higher_order_ops/torchbind.py +++ b/torch/_higher_order_ops/torchbind.py @@ -26,6 +26,9 @@ class CallTorchBind(HigherOrderOperator): def __init__(self): super().__init__("call_torchbind") + def __call__(self, obj, method, *args, **kwargs): + return super().__call__(obj, method, *args, **kwargs) + call_torchbind = CallTorchBind() diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 65de706a529142..a0548a1c8ee125 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -158,15 +158,13 @@ def generate_ttir(kernel, kwargs): ] specialization = kernel._get_config(*ordered_args.values()) constants = { - i: arg - for i, arg in enumerate(ordered_args.values()) - if not isinstance(arg, Tensor) + name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) } # Build kernel signature -- doesn't include constexpr arguments. signature = { - i: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(ordered_args.values()) + name: kernel._type_of(kernel._key_of(arg)) + for i, (name, arg) in enumerate(ordered_args.items()) if i not in kernel.constexprs } @@ -179,13 +177,18 @@ def generate_ttir(kernel, kwargs): src = ASTSource(kernel, signature, constants, specialization) - # Triton changes ASTSource.make_ir to take 3 arguments. Handle + # Triton changes ASTSource.make_ir to take 3/4 arguments. Handle # backward compatibility here. - if len(inspect.signature(src.make_ir).parameters) == 2: + make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) + if make_ir_sig_params == 2: ttir_module = src.make_ir(options, context) - else: + elif make_ir_sig_params == 3: codegen_fns = backend.get_codegen_implementation() ttir_module = src.make_ir(options, codegen_fns, context) + else: + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + ttir_module = src.make_ir(options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") @@ -522,6 +525,14 @@ class TritonKernelWrapperMutation(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_mutation") + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + ) + triton_kernel_wrapper_mutation = TritonKernelWrapperMutation() @@ -531,6 +542,15 @@ class TritonKernelWrapperFunctional(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_functional") + def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone): + return super().__call__( + kernel_idx=kernel_idx, + constant_args_idx=constant_args_idx, + grid=grid, + kwargs=kwargs, + tensors_to_clone=tensors_to_clone, + ) + triton_kernel_wrapper_functional = TritonKernelWrapperFunctional() @@ -604,19 +624,23 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( return None +def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs): + kernel = kernel_side_table.get_kernel(kernel_idx) + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + return identify_mutated_tensors(kernel, {**kwargs, **constant_args}) + + @triton_kernel_wrapper_mutation.py_functionalize_impl def triton_kernel_wrapper_mutation_functionalize( ctx, kernel_idx, constant_args_idx, grid, kwargs ): unwrapped_kwargs = ctx.unwrap_tensors(kwargs) - kernel = kernel_side_table.get_kernel(kernel_idx) - constant_args = kernel_side_table.get_constant_args(constant_args_idx) # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each # other, and one gets mutated in kernel, and later another gets mutated, # they are no longer equal. Fix this by graph breaking on this condition # earlier in dynamo. - tensors_to_clone = identify_mutated_tensors( - kernel, {**unwrapped_kwargs, **constant_args} + tensors_to_clone = get_mutated_tensors( + kernel_idx, constant_args_idx, unwrapped_kwargs ) with ctx.redispatch_to_next(): unwrapped_outputs = triton_kernel_wrapper_functional( diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index b0729b25ab64e5..139e9a160cbe20 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -105,21 +105,13 @@ def _maybe_reenter_make_fx(fn): @contextmanager def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag - _old_is_inlining = torch._dynamo.config.inline_inbuilt_nn_modules try: # We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False - - # TODO(anijain2305, export-team) For non-strict export with module - # stack info, the codepatch forces the nn module __getattr__ to - # ProxyAttr __getattr__ downstream. To circumvent the issue for now, - # skip inlining inbuilt nn modules for cond. - torch._dynamo.config.inline_inbuilt_nn_modules = False yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing - torch._dynamo.config.inline_inbuilt_nn_modules = _old_is_inlining def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): @@ -232,6 +224,7 @@ def _from_fun(t): t.stride(), dtype=t.dtype, requires_grad=t.requires_grad, + device=t.device, ) else: # clone of a functional tensor produces a functional tensor diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 87764a826d9276..f14321842f40b9 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -7,6 +7,7 @@ from torch._higher_order_ops.utils import ( _has_potential_branch_input_alias, _has_potential_branch_input_mutation, + _maybe_run_with_interpreter, _set_compilation_env, autograd_not_implemented, reenter_make_fx, @@ -14,7 +15,11 @@ ) from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, + ProxyTorchDispatchMode, + track_tensor_tree, +) class WhileLoopOp(HigherOrderOperator): @@ -112,6 +117,9 @@ def body_fn(iter, x): - 'while_loop' only supports **inference** right now. Autograd will be supported in the future. """ + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo. # parameters and buffers accessed in cond_fn or body_fn or tensor closures will become additional_inputs. @@ -139,9 +147,15 @@ def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit(): - return torch.compile(_while_loop_op_wrapper, backend="eager", fullgraph=True)( - cond_fn, body_fn, carried_inputs, additional_inputs - ) + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode(metadata_mode) + else: + backend = "eager" + return torch.compile( + _while_loop_op_wrapper, backend=backend, fullgraph=True + )(cond_fn, body_fn, carried_inputs, additional_inputs) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) @@ -238,8 +252,8 @@ def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs): unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs with ctx.redispatch_to_next() as m: - functional_cond_fn = ctx.functionalize(cond_fn) - functional_body_fn = ctx.functionalize(body_fn) + functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn)) + functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn)) pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch for fn, fn_name in [ (functional_cond_fn, "cond_fn"), diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index f95e7caaf71e95..404869debf1766 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -30,6 +30,48 @@ def compile( return compile_fx(gm, example_inputs, config_patches=options) +def aoti_compile_and_package( + exported_program, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + *, + package_path: Optional[str] = None, + inductor_configs: Optional[Dict[str, Any]] = None, +) -> str: + """ + Compiles the exported program with AOTInductor, and packages it into a .pt2 + file specified by the input package_path. + """ + from torch._inductor.package import package_aoti + from torch.export import ExportedProgram + + if not isinstance(exported_program, ExportedProgram): + raise ValueError("Only ExportedProgram is supported") + + assert package_path is None or package_path.endswith(".pt2") + + inductor_configs = inductor_configs or {} + + if inductor_configs.get("aot_inductor.output_path"): + raise RuntimeError( + "Please pass in a package path to aot_inductor_compile() instead " + "of setting the aot_inductor.output_path config." + ) + inductor_configs["aot_inductor.package"] = True + + m = exported_program.module() + assert isinstance(m, torch.fx.GraphModule) + + aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type] + + if package_path is None: + package_path = aoti_files + ".pt2" + + res = package_aoti(package_path, aoti_files) + assert res == package_path + return package_path + + def aot_compile( gm: torch.fx.GraphModule, args: Tuple[Any], diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 1ba70135bd7f8f..26759df1c59c17 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -7,6 +7,7 @@ import os import sys from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool from functools import partial from time import time from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING @@ -117,6 +118,16 @@ def after_fork(): pass # register_at_fork does not exists on windows +def get_worker_start_method() -> str: + """ + Temporary for internal subprocess pool rollout. Assign config.worker_start_method + lazily and return it. TODO: remove after rollout. + """ + if config.worker_start_method is None: + config.worker_start_method = config.decide_worker_start_method() + return config.worker_start_method + + class AsyncCompile: def __init__(self) -> None: pass @@ -127,17 +138,22 @@ def pool() -> ThreadPoolExecutor: assert config.compile_threads > 1 return ThreadPoolExecutor(config.compile_threads) + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + @staticmethod @functools.lru_cache(1) def process_pool() -> AnyPool: assert config.compile_threads > 1 pool: AnyPool - if config.worker_start_method == "subprocess": + if get_worker_start_method() == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control pool = SubprocPool(config.compile_threads) else: pre_fork_setup() - ctx = multiprocessing.get_context(config.worker_start_method) + ctx = multiprocessing.get_context(get_worker_start_method()) pool = ProcessPoolExecutor( config.compile_threads, mp_context=ctx, @@ -149,6 +165,8 @@ def process_pool() -> AnyPool: # kill the worker thread that sends the shutdown message to the workers... multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + # Set an attribute we can check to see if the pool is ready. + pool.ready_future = pool.submit(AsyncCompile._get_ready) # type: ignore[union-attr] _pool_set.add(pool) return pool @@ -166,13 +184,19 @@ def submit(cls, task: Callable[..., Any]) -> Any: return task() return cls.pool().submit(task) + def _use_process_pool(self): + return ( + config.compile_threads > 1 + and self.process_pool().ready_future.done() # type: ignore[union-attr] + ) + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): kernel_code_log.info("Triton Kernel:\n%s", source_code) _compile_start() _set_triton_ptxas_path() kernel = TritonCodeCache.load(kernel_name, source_code) - if config.compile_threads > 1: + if self._use_process_pool(): # We want to support changing these env vars after (and while) the # process pool is running, so pass them to the subprocess to reset. env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] @@ -258,7 +282,15 @@ def wait(self, scope: Dict[str, Any]) -> None: if config.verbose_progress and not isinstance(pbar, _Faketqdm): pbar.set_postfix_str(key) if isinstance(result, (Future, CodeCacheFuture)): - scope[key] = result.result() + try: + scope[key] = result.result() + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." + ) from e pbar.update(1) _compile_end() @@ -269,6 +301,9 @@ def wait(self, scope: Dict[str, Any]) -> None: or os.environ.get("TORCH_WARM_POOL", "1") != "1" # The subprocess pool is only used for the Triton backend or not has_triton_package() + # Skip for fbcode so we can query the worker_start_method lazily. + # TODO: remove once "subprocess" has rolled out internally. + or config.is_fbcode() ): pass else: diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index eea4f8d6573d82..94efbaf4e32fdd 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -825,14 +825,14 @@ def precompile(self): # Prepopulate CppCodeCache # may happen in separate Threadpool log.debug("Precompiling %s", self) - CppCodeCache.load(self.source_code, cuda=False) + CppCodeCache.load(self.source_code, device_type="cpu") log.debug("Done precompiling %s", self) def make_run_fn( self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor ) -> Callable[[], None]: # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf - self.DLL = CppCodeCache.load(self.source_code, cuda=False) + self.DLL = CppCodeCache.load(self.source_code, device_type="cpu") args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] log.debug( "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index c2b0fe201fd685..7452f2bb1b62b6 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -9,7 +9,7 @@ import torch from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges -from .ir import InterpreterShim, LoopBody, LoopBodyBlock +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock from .utils import cache_on_self, dominated_nodes from .virtualized import V diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 59e0e8ba9daf5f..135ebf96a86b69 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -53,13 +53,21 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor -from torch._dynamo.utils import counters, dynamo_timed, get_chromium_event_logger +from torch._dynamo.utils import ( + add_remote_cache_time_saved, + counters, + dynamo_timed, + get_chromium_event_logger, +) from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env from torch._inductor.codegen.rocm.compile_command import ( rocm_compile_command, rocm_compiler, ) +from torch._utils_internal import log_cache_bypass + +from .utils import _align T = TypeVar("T") @@ -68,7 +76,7 @@ if TYPE_CHECKING: from collections.abc import KeysView - from .remote_cache import RemoteCacheBackend + from .remote_cache import JsonDataTy, RemoteCache """ @@ -80,7 +88,7 @@ _transform_cuda_paths, CppBuilder, CppOptions, - CppTorchCudaOptions, + CppTorchDeviceOptions, get_compiler_version_info, get_cpp_compiler, get_name_and_dir_from_output_file_path, @@ -142,16 +150,16 @@ ) else: - def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] pass - def use_global_cache() -> bool: + def use_global_cache() -> bool: # type: ignore[misc] return False @@ -443,7 +451,7 @@ def write_text(text: str) -> str: def write_atomic( - path: str, + path_: str, content: Union[str, bytes], make_dirs: bool = False, encode_utf_8: bool = False, @@ -453,7 +461,7 @@ def write_atomic( assert isinstance( content, (str, bytes) ), "Only strings and byte arrays can be saved in the cache" - path = Path(path) + path = Path(path_) if make_dirs: path.parent.mkdir(parents=True, exist_ok=True) tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" @@ -530,7 +538,7 @@ def _reduce_tensor( # TODO: These tensors don't currently pickle, so we can't cache a # compiled graph containing them. Just fail now. If mkldnn tensors # get pickling support, we can remove this. - raise BypassFxGraphCache("mkldnn tensors unpickleable.") + raise BypassFxGraphCache("mkldnn tensors unpickleable") # Very large tensors could be expensive to copy to cpu and hash. Let's # at least report if we find slowness. @@ -561,7 +569,7 @@ def _reduce_unsupported(s: Any) -> NoReturn: See FxGraphCachePickler. Custom reducer to handle any objects that we don't support and therefore raise to bypass caching. """ - raise BypassFxGraphCache("Reduce unsupported.") + raise BypassFxGraphCache("Reduce unsupported") class FxGraphCachePickler(pickle.Pickler): @@ -599,7 +607,7 @@ def dumps(cls, obj: Any) -> bytes: # Some configs options are callables, e.g., post_grad_custom_pre_pass, # and may not pickle. log.warning("Can't pickle", exc_info=True) - raise BypassFxGraphCache("Config options may be unpickleable.") from e + raise BypassFxGraphCache("Config options may be unpickleable") from e return stream.getvalue() @classmethod @@ -694,7 +702,7 @@ def get_code_hash(root: str) -> bytes: from libfb.py import parutil - return parutil.get_file_contents("torch/src_hash.txt").rstrip() + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") def get_inductor_root() -> str: @@ -837,8 +845,10 @@ def cudagraph_post_compile( from .compile_fx import cudagraphify + current_callable = compiled_graph.current_callable + assert current_callable is not None compiled_graph.current_callable = cudagraphify( - compiled_graph.current_callable, + current_callable, static_input_idxs=static_input_idxs, device_index=next(iter(compiled_graph.device_idxs)), stack_traces=stack_traces, @@ -1000,7 +1010,7 @@ def _lookup_graph( key: str, example_inputs: List[torch.Tensor], local: bool, - remote_cache: Optional[Any], + remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> Optional[CompiledFxGraph]: """ Lookup a compiled graph in the cache by key. On a hit, return the @@ -1028,8 +1038,12 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: if remote_cache: try: - if (data := remote_cache.get(key)) is not None: - yield pickle.loads(data) + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content) except Exception: log.warning( "fx graph cache unable to load compiled graph", exc_info=True @@ -1179,7 +1193,7 @@ def _save_graph( compiled_graph: CompiledFxGraph, example_inputs: List[torch.Tensor], local: bool, - remote_cache: Optional[RemoteCacheBackend], + remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> None: """ Store a serialized CompiledFxGraph on disk. @@ -1227,15 +1241,11 @@ def _save_graph( if remote_cache: time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) - cache_data = ( - { - "data": content, - "time_taken_ms": time_taken_ms, - } - if config.is_fbcode() - else content - ) - remote_cache.put(key, cache_data) # type: ignore[arg-type] + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) except Exception: log.warning("fx graph unable to write to cache", exc_info=True) counters["inductor"]["fxgraph_cache_write_error"] += 1 @@ -1249,25 +1259,132 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: # Freezing can embed constants that wouldn't be static across runs. if config.freezing or config.aot_inductor.use_runtime_constant_folding: raise BypassFxGraphCache( - "Freezing may introduce constants that aren't static across runs." + "Freezing may introduce constants that aren't static across runs" ) # The treatment of guards in the caching implementation requires that # we have a shape env. if FxGraphCache._get_shape_env() is None: log.debug("fx graph cache no shape env") - raise BypassFxGraphCache("No shape env.") + raise BypassFxGraphCache("No shape env") # HigherOrderOperators should be handled on a case-by-case basis. # Currently, we just skip caching if we have any. # We also skip if there are any torchbind objects. for node in gm.graph.nodes: if isinstance(node.target, torch._ops.HigherOrderOperator): - raise BypassFxGraphCache("Can't cache HigherOrderOperators.") + raise BypassFxGraphCache( + f"Can't cache HigherOrderOperator: {node.target.name()}" + ) if node.op == "getattr" and isinstance( getattr(gm, node.target), torch._C.ScriptObject ): - raise BypassFxGraphCache("Can't cache torchbind objects.") + raise BypassFxGraphCache("Can't cache torchbind objects") + + @staticmethod + def prepare_key( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + fx_kwargs: Dict[str, Any], + inputs_to_check: Sequence[int], + remote: bool, + ) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]: + """ + Checks that the inductor input is cacheable, then computes + and returns the cache key for the input. + Returns (key_info, cache_info) where: + - key_info is (hash_key, debug_lines), and + - cache_info will contain debug info in the event of BypassFxGraphCache. + + NB: It is possible to have this function return a union instead. But + I personally believe it is more annoying/difficult to read in that format. + """ + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + log.info("Bypassing FX Graph Cache because '%s'", e) + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_info = { + "cache_state": "bypass", + "cache_bypass_reason": str(e), + "cache_event_time": time_ns(), + } + return None, cache_info + # If key exists, then cache_info will come from load_with_key + return (key, debug_lines), {} + + @staticmethod + def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: + """ + Attempts to load the remote cache, returns None on error. + """ + remote_cache = None + cache_id = "fx-graph-v1" + try: + if config.is_fbcode(): + from torch._inductor.fb.remote_cache import FbRemoteFxGraphCache + + remote_cache = FbRemoteFxGraphCache(cache_id) + else: + from torch._inductor.remote_cache import RemoteFxGraphCache + + remote_cache = RemoteFxGraphCache(cache_id) + except ModuleNotFoundError as e: + # No need for a stack trace on this error + remote_cache = None + log.warning("Unable to create a remote cache: %s", e) + except Exception: + remote_cache = None + log.warning("Unable to create a remote cache", exc_info=True) + return remote_cache + + @staticmethod + def load_with_key( + key: str, + debug_lines: List[str], + example_inputs: List[torch.Tensor], + local: bool, + remote_cache: Optional[RemoteCache[JsonDataTy]], + is_backward: bool, + ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]: + """ + Lookup the graph with the given key, and return results and metadata. + Doesn't do any logging on its own, because AOTAutograd handles a cache miss + differently from FXGraphCache. + """ + compiled_graph = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache + ) + cache_info = { + "key": key, + "components": debug_lines, + "cache_event_time": time_ns(), + } + if compiled_graph is not None: + log.debug("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_info["cache_state"] = "hit" + + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + add_remote_cache_time_saved(time_saved_ns, is_backward) + if ( + ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( + time_saved_ns + ) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + else: + log.debug("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_info["cache_state"] = "miss" + + return compiled_graph, cache_info @staticmethod def load( # type: ignore[no-untyped-def] @@ -1285,86 +1402,70 @@ def load( # type: ignore[no-untyped-def] """ assert local or remote, "at least one of them needs to be enabled" compiled_graph = None - cache_state = None - cache_event_time = None - cache_info: Dict[str, Any] = {} - try: - FxGraphCache._check_can_cache(gm) - key, debug_lines = compiled_fx_graph_hash( - gm, example_inputs, fx_kwargs, inputs_to_check - ) - cache_info["key"] = key - cache_info["components"] = debug_lines - - remote_cache: Optional[RemoteCacheBackend] = None + remote_cache = None + (key_info, cache_info) = FxGraphCache.prepare_key( + gm, example_inputs, fx_kwargs, inputs_to_check, remote + ) + if key_info is not None: + key, debug_lines = key_info if remote: - cache_id = "fx-graph-v1" - try: - if config.is_fbcode(): - from torch._inductor.fb.remote_cache import ( - FbRemoteFxGraphCacheBackend, - ) - - remote_cache = FbRemoteFxGraphCacheBackend(cache_id) - else: - from torch._inductor.remote_cache import RedisRemoteCacheBackend + remote_cache = FxGraphCache.get_remote_cache() + compiled_graph, cache_info = FxGraphCache.load_with_key( + key, + debug_lines, + example_inputs, + local, + remote_cache, + is_backward=fx_kwargs.get("is_backward", False), + ) - remote_cache = RedisRemoteCacheBackend(cache_id) - except Exception: - remote_cache = None - log.warning("Unable to create a remote cache", exc_info=True) + # CACHE BYPASS: Compile the graph, don't save it to the cache + if cache_info["cache_state"] == "bypass": + assert compiled_graph is None + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) - compiled_graph = FxGraphCache._lookup_graph( - key, example_inputs, local, remote_cache + # CACHE MISS: Compile the graph and save to cache + elif cache_info["cache_state"] == "miss": + assert compiled_graph is None + assert key_info is not None + start_time = cache_info["cache_event_time"] + compiled_graph = compile_fx_fn( + gm, example_inputs, inputs_to_check, fx_kwargs + ) + compiled_graph._time_taken_ns = time_ns() - start_time + cache_key = key_info[0] + compiled_graph._fx_graph_cache_key = cache_key + cache_info["time_taken_ns"] = compiled_graph._time_taken_ns + FxGraphCache._save_graph( + cache_key, + compiled_graph, + example_inputs, + local, + remote_cache, ) + # CACHE HIT: not much to really do, just make sure the cache key + # is recorded on the graph + else: + assert cache_info["cache_state"] == "hit" + assert compiled_graph is not None + assert key_info is not None + cache_key = key_info[0] + compiled_graph._fx_graph_cache_key = cache_key - if compiled_graph is None: - log.debug("fx graph cache miss for key %s", key) - counters["inductor"]["fxgraph_cache_miss"] += 1 - cache_state = "miss" - start_time = time_ns() - cache_event_time = start_time - compiled_graph = compile_fx_fn( - gm, example_inputs, inputs_to_check, fx_kwargs - ) - compiled_graph._time_taken_ns = time_ns() - start_time - cache_info["time_taken_ns"] = compiled_graph._time_taken_ns - FxGraphCache._save_graph( - key, - compiled_graph, - example_inputs, - local, - remote_cache, - ) - else: - log.debug("fx graph cache hit for key %s", key) - counters["inductor"]["fxgraph_cache_hit"] += 1 - cache_state = "hit" - cache_event_time = time_ns() - if (time_saved_ns := compiled_graph._time_taken_ns) is not None: - cache_info["time_saved_ns"] = time_saved_ns - if ( - ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( - time_saved_ns - ) - ) != 0: - cache_info["ephemeral_timeout_increase"] = ephemeral_increase - compiled_graph._fx_graph_cache_key = key - except BypassFxGraphCache as e: - counters["inductor"]["fxgraph_cache_bypass"] += 1 - cache_state = "bypass" - log.info("Bypassing FX Graph Cache because '%s'", e) - cache_info["cache_bypass_reason"] = str(e) - cache_event_time = time_ns() - if not compiled_graph: - compiled_graph = compile_fx_fn( - gm, example_inputs, inputs_to_check, fx_kwargs - ) assert compiled_graph is not None - cache_info["cache_state"] = cache_state + + # Logging and observability: we log a single chromium event + # and a tlparse log for every cache action. + # In the event of a bypass, we also logged to the remote table earlier + # with log_cache_bypass. chromium_log = get_chromium_event_logger() + cache_state = cache_info["cache_state"] chromium_log.log_instant_event( - f"fx_graph_cache_{cache_state}", cache_event_time, metadata=cache_info + f"fx_graph_cache_{cache_state}", + cache_info["cache_event_time"], + metadata=cache_info, ) torch._logging.trace_structured( "artifact", @@ -1458,14 +1559,18 @@ def __init__( self.metrics_deltas = metrics_deltas self.counter_deltas = counter_deltas self.guards_expr = None + self.cudagraph_info = None + self.fx_kwargs = {} + self.inputs_to_check = () + self.boxed_forward_device_index = None def __call__(self, inputs: List[Any]) -> Any: assert self.current_callable is not None return self.current_callable(inputs) -def run_command_and_check(cmd: str) -> None: - cmd = shlex.split(cmd) +def run_command_and_check(cmd_: str) -> None: + cmd = shlex.split(cmd_) try: subprocess.check_call(cmd) except subprocess.CalledProcessError as e: @@ -1498,7 +1603,6 @@ def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> Non config.aot_inductor.output_path )[0], ) - params[get_cpp_wrapper_cubin_path_name()] = path cls.cache[key] = params @@ -1519,7 +1623,7 @@ def compile( graph: GraphLowering, source_code: str, serialized_extern_kernel_nodes: Optional[str], - cuda: bool, + device_type: str, ) -> str: if sys.platform == "win32": raise RuntimeError("AotCodeCompiler not yet supported for inductor") @@ -1530,9 +1634,9 @@ def compile( vec_isa_cmd_gen = CppBuilder( name="o", sources="i", - BuildOption=CppTorchCudaOptions( + BuildOption=CppTorchDeviceOptions( vec_isa=picked_vec_isa, - cuda=cuda, + device_type=device_type, aot_mode=graph.aot_mode, ), ) @@ -1547,7 +1651,7 @@ def compile( use_absolute_path = False if config.is_fbcode(): ld_command = build_paths.ld() - if not cuda and graph.aot_mode: # Meta internal AOTInductor CPU + if device_type == "cpu" and graph.aot_mode: # Meta internal AOTInductor CPU objcopy_command = build_paths.objcopy_fallback() fbcode_aot_cpu_re = True use_absolute_path = True @@ -1706,10 +1810,23 @@ def _compile_consts_darwin(consts: bytes) -> str: # Currently, this only support serializing extern nodes in fbcode # Eventually, we should also have a serializer for OSS. if serialized_extern_kernel_nodes: - output_json = os.path.splitext(input_path)[0] + ".json" - with open(output_json, "w") as f: + extern_kernel_nodes_json = os.path.splitext(input_path)[0] + ".json" + with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) + metadata = config.aot_inductor.metadata + metadata["AOTI_DEVICE_KEY"] = device_type + + # Save user provided metadata + meta_json = os.path.splitext(input_path)[0] + "_metadata.json" + for k, v in config.aot_inductor.metadata.items(): + assert isinstance(k, str) and isinstance( + v, (str) + ), "Metadata must only contain strings" + + with open(meta_json, "w") as f: + f.write(json.dumps(config.aot_inductor.metadata)) + output_so = ( config.aot_inductor.output_path if specified_so_name @@ -1717,10 +1834,23 @@ def _compile_consts_darwin(consts: bytes) -> str: ) output_o = os.path.splitext(input_path)[0] + ".o" + + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants.keys() + if name not in graph.folded_constants + ) + + def get_nbytes_of_tensor(tensor: torch.Tensor, all_cuda: bool) -> int: + n_bytes = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + return n_bytes if all_cuda else _align(n_bytes) + consts_size = sum( - torch.ops.mkldnn._nbytes(tensor) - if tensor.is_mkldnn - else tensor.untyped_storage().nbytes() + get_nbytes_of_tensor(tensor, all_cuda) for (name, tensor) in graph.constants.items() if name not in graph.folded_constants ) @@ -1729,54 +1859,29 @@ def _compile_consts_darwin(consts: bytes) -> str: if config.aot_inductor.force_mmap_weights: use_mmap_weights = True - if config.aot_inductor.package: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchCudaOptions( - vec_isa=picked_vec_isa, - cuda=cuda, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() - - compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" - object_build_options.save_flags_to_file(compile_flags) - - else: - ( - object_output_name, - object_output_dir, - ) = get_name_and_dir_from_output_file_path(input_path) - object_build_options = CppTorchCudaOptions( - vec_isa=picked_vec_isa, - cuda=cuda, - aot_mode=graph.aot_mode, - compile_only=True, - use_absolute_path=use_absolute_path, - use_mmap_weights=use_mmap_weights, - ) - object_builder = CppBuilder( - name=object_output_name, - sources=input_path, - output_dir=object_output_dir, - BuildOption=object_build_options, - ) - compile_cmd = object_builder.get_command_line() - output_o = object_builder.get_target_file_path() + ( + object_output_name, + object_output_dir, + ) = get_name_and_dir_from_output_file_path(input_path) + object_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + compile_only=True, + use_absolute_path=use_absolute_path, + use_mmap_weights=use_mmap_weights, + ) + object_builder = CppBuilder( + name=object_output_name, + sources=input_path, + output_dir=object_output_dir, + BuildOption=object_build_options, + ) + compile_cmd = object_builder.get_command_line() + output_o = object_builder.get_target_file_path() - log.debug("aot compilation command: %s", compile_cmd) + log.debug("aot compilation command: %s", compile_cmd) + if not config.aot_inductor.package_cpp_only: if fbcode_aot_cpu_re: output_o = os.path.splitext(input_path)[0] + ".o" compile_file(input_path, output_o, compile_cmd.split()) @@ -1784,6 +1889,10 @@ def _compile_consts_darwin(consts: bytes) -> str: else: run_command_and_check(compile_cmd) + if config.aot_inductor.package: + compile_flags = os.path.splitext(input_path)[0] + "_compile_flags.json" + object_build_options.save_flags_to_file(compile_flags) + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: def _pad_to_alignment(raw_bytes: bytes) -> bytes: padded_bytes = raw_bytes.ljust( @@ -1814,11 +1923,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: raw_bytes = bytes(raw_array.contents) return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) - all_cuda = all( - graph.get_original_value_of_constant(name).is_cuda - for name in graph.constants.keys() - if name not in graph.folded_constants - ) serialized_weights = b"".join( _to_bytes(graph.get_original_value_of_constant(name), all_cuda) for name in graph.constants.keys() @@ -1838,29 +1942,37 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: "darwin": _compile_consts_darwin, }[sys.platform](aot_constants) - if config.aot_inductor.package: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchCudaOptions( - vec_isa=picked_vec_isa, - cuda=cuda, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_absolute_path=use_absolute_path, + ) + so_builder = CppBuilder( + name=output_name, + sources=[output_o, consts_o], + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + # Append cmds to the end of codegen-ed wrapper file + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + if config.aot_inductor.package: linker_flags = os.path.splitext(input_path)[0] + "_linker_flags.json" so_build_options.save_flags_to_file(linker_flags) - from torch._inductor.package import package_aoti + if config.aot_inductor.package_cpp_only: + # If we only want to package the cpp, then we need to save the + # weights separately into a bin, and we also need to prevent compiling the so if use_mmap_weights: weight_file = ( @@ -1870,28 +1982,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: f_weights.write(serialized_weights) f_weights.write(struct.pack("q", magic_number)) - archive_path = package_aoti(os.path.split(input_path)[0]) - return archive_path else: - output_name, output_dir = get_name_and_dir_from_output_file_path( - output_so - ) - so_build_options = CppTorchCudaOptions( - vec_isa=picked_vec_isa, - cuda=cuda, - aot_mode=graph.aot_mode, - use_absolute_path=use_absolute_path, - ) - so_builder = CppBuilder( - name=output_name, - sources=[output_o, consts_o], - output_dir=output_dir, - BuildOption=so_build_options, - ) - link_cmd = so_builder.get_command_line() - output_so = so_builder.get_target_file_path() - - log.debug("aot linkage command: %s", link_cmd) if fbcode_aot_cpu_re: output_so = ( config.aot_inductor.output_path @@ -1904,18 +1995,22 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: run_command_and_check(link_cmd) if use_mmap_weights: + import resource + + page_size_ = resource.getpagesize() + page_size = max(16384, page_size_) + with open(output_so, "a+b") as f_so: so_size = f_so.tell() # Page align the weights - f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(b" " * (page_size - so_size % page_size)) f_so.write(serialized_weights) f_so.write(struct.pack("q", magic_number)) - # Append cmds to the end of codegen-ed wrapper file - with open(input_path, "a") as f: - f.write("\n") - f.write(f"// Compile cmd\n// {compile_cmd}\n") - f.write(f"// Link cmd\n// {link_cmd}\n") + if config.aot_inductor.package: + # We want to return the directory that contains all the AOTI + # generated files, not just the so + return os.path.split(output_so)[0] return output_so @@ -2039,7 +2134,9 @@ def convert_arg(arg: Any) -> Any: assert callable(func), op + " can not be loaded through custom_op_wrapper" result = func(*converted_args) if isinstance(result, (list, tuple)): - for r in result: + # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only + result = [torch.tensor([]) if r is None else r for r in result] + for i, r in enumerate(result): assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] else: @@ -2084,13 +2181,13 @@ def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]: def load_async( cls, source_code: str, - cuda: bool = False, + device_type: str = "cpu", submit_fn: Any = None, extra_flags: Sequence[str] = (), ) -> Any: compile_command = { **cls.cpp_compile_command_flags, - "cuda": cuda, + "device_type": device_type, "vec_isa": pick_vec_isa(), "extra_flags": extra_flags, } @@ -2098,7 +2195,7 @@ def load_async( _set_gpu_runtime_env() # cpp_extension consults the env command_gen = CppBuilder( - name="o", sources="i", BuildOption=CppTorchCudaOptions(**compile_command) + name="o", sources="i", BuildOption=CppTorchDeviceOptions(**compile_command) ) # write function will calc source_code hash, the same source code with different # ISA level should be generate different hash. @@ -2121,7 +2218,7 @@ def load_async( future: Optional[Future[Any]] = None lib = None - cpp_build_option = CppTorchCudaOptions(**compile_command) + cpp_build_option = CppTorchDeviceOptions(**compile_command) cpp_builder = CppBuilder( name=output_name, sources=input_path, @@ -2164,8 +2261,8 @@ def load_fn() -> Any: return cls.cache[key] @classmethod - def load(cls, source_code: str, cuda: bool = False) -> Any: - return cls.load_async(source_code, cuda)() + def load(cls, source_code: str, device_type: str = "cpu") -> Any: + return cls.load_async(source_code, device_type)() def _worker_compile_cpp( @@ -2308,7 +2405,7 @@ def load_pybinding_async( cls, argtypes: List[str], source_code: str, - cuda: bool = False, + device_type: str = "cpu", num_outputs: int = -1, submit_fn: Any = None, extra_flags: Sequence[str] = (), @@ -2340,7 +2437,10 @@ def load_pybinding_async( cls.entry_function, ) get_result = cls.load_async( - source_code + suffix, cuda, submit_fn=submit_fn, extra_flags=extra_flags + source_code + suffix, + device_type, + submit_fn=submit_fn, + extra_flags=extra_flags, ) result = None @@ -2691,7 +2791,7 @@ def generate_halide_async( cls._codegen_glue(meta, headerfile), extra_flags=(libfile, cls.build_standalone_runtime()), submit_fn=jobs.append if need_compile else None, - cuda=meta.is_cuda(), + device_type="cuda" if meta.is_cuda() else "cpu", ) if need_compile: @@ -2719,9 +2819,9 @@ def build_standalone_runtime(cls) -> str: cls._standalone_runtime_path ): return cls._standalone_runtime_path - is_cuda = torch.cuda.is_available() + device_type = "cuda" if torch.cuda.is_available() else "cpu" libname = "libStandaloneHalideRuntime.so" - target = "host-cuda" if is_cuda else "host" + target = "host-cuda" if device_type == "cuda" else "host" if cls._standalone_runtime_path: assert not os.path.exists(cls._standalone_runtime_path) # We hit this case in unittests when we run with fresh_inductor_cache() @@ -2745,7 +2845,7 @@ def build_standalone_runtime(cls) -> str: with filelock.FileLock(lockfile, LOCK_TIMEOUT): if not os.path.exists(donefile): with open(hookfile, "w") as f: - if is_cuda: + if device_type == "cuda": f.write( cls.standalone_runtime_cuda_init.format( cls.find_header("HalideRuntimeCuda.h") @@ -2758,8 +2858,8 @@ def build_standalone_runtime(cls) -> str: name=name, sources=[hookfile, afile], output_dir=output_dir, - BuildOption=CppTorchCudaOptions( - cuda=is_cuda, + BuildOption=CppTorchDeviceOptions( + device_type=device_type, ), ) @@ -2931,7 +3031,7 @@ def _cuda_lib_options() -> List[str]: _set_gpu_runtime_env() # cpp_extension consults the env from torch.utils import cpp_extension - lpaths = cpp_extension.library_paths(cuda=True) + [ + lpaths = cpp_extension.library_paths(device_type="cuda") + [ sysconfig.get_config_var("LIBDIR") ] extra_ldflags: List[str] = [] diff --git a/torch/_inductor/codegen/codegen_device_driver.py b/torch/_inductor/codegen/codegen_device_driver.py deleted file mode 100644 index c31017fe6471cb..00000000000000 --- a/torch/_inductor/codegen/codegen_device_driver.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch - - -# Provide aoti module launch hip/cuda drivers. This file is also used for unit testing purpose - - -def cuda_kernel_driver() -> str: - source_codes = """ - #define CUDA_DRIVER_CHECK(EXPR) \\ - do { \\ - CUresult code = EXPR; \\ - const char *msg; \\ - cuGetErrorString(code, &msg); \\ - if (code != CUDA_SUCCESS) { \\ - throw std::runtime_error( \\ - std::string("CUDA driver error: ") + \\ - std::string(msg)); \\ - } \\ - } while (0); - - namespace { - - struct Grid { - Grid(uint32_t x, uint32_t y, uint32_t z) - : grid_x(x), grid_y(y), grid_z(z) {} - uint32_t grid_x; - uint32_t grid_y; - uint32_t grid_z; - - bool is_non_zero() { - return grid_x > 0 && grid_y > 0 && grid_z > 0; - } - }; - - } // anonymous namespace - - static inline CUfunction loadKernel( - std::string filePath, - const std::string &funcName, - uint32_t sharedMemBytes, - const std::optional &cubinDir = std::nullopt) { - if (cubinDir) { - std::filesystem::path p1{*cubinDir}; - std::filesystem::path p2{filePath}; - filePath = (p1 / p2.filename()).string(); - } - - CUmodule mod; - CUfunction func; - CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); - CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); - if (sharedMemBytes > 0) { - CUDA_DRIVER_CHECK(cuFuncSetAttribute( - func, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - sharedMemBytes - )) - } - return func; - } - - static inline void launchKernel( - CUfunction func, - uint32_t gridX, - uint32_t gridY, - uint32_t gridZ, - uint32_t numWarps, - uint32_t sharedMemBytes, - void* args[], - cudaStream_t stream) { - CUDA_DRIVER_CHECK(cuLaunchKernel( - func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr - )); - } - """ - if torch.version.hip is not None: - # Adjusting the warp size to GPU supported wavefront size on AMD GPU - prop = torch.cuda.get_device_properties(torch.cuda.current_device()) - source_codes = source_codes.replace( - "32*numWarps", str(prop.warp_size) + "*numWarps" - ) - return source_codes - - -def cuda_kernel_header() -> str: - source_codes = """ - #include - #include - #include - """ - return source_codes diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 4f5e65f5a25b25..95183c1484fe00 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -109,6 +109,42 @@ def synchronize(self): def device_guard(self, device_idx): raise NotImplementedError + def cpp_device_guard(self): + raise NotImplementedError + + def cpp_aoti_device_guard(self): + raise NotImplementedError + + def cpp_stream_guard(self): + raise NotImplementedError + + def cpp_aoti_stream_guard(self): + raise NotImplementedError + + def cpp_getStreamFromExternal(self): + raise NotImplementedError + + def kernel_header(self): + raise NotImplementedError + + def kernel_driver(self): + raise NotImplementedError + + def abi_compatible_header(self): + raise NotImplementedError + + def cpp_stream_type(self): + raise NotImplementedError + + def aoti_get_stream(self): + raise NotImplementedError + + def cpp_kernel_type(self): + raise NotImplementedError + + def cpp_device_ptr(self): + raise NotImplementedError + device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} @@ -196,7 +232,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): def init_backend_registration(): from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu - from .cpp_wrapper_cuda import CppWrapperCuda + from .cpp_wrapper_gpu import CppWrapperGpu from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .triton import TritonScheduling @@ -218,11 +254,15 @@ def init_backend_registration(): "cuda", lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), WrapperCodeGen, - CppWrapperCuda, + CppWrapperGpu, ) if get_scheduling_for_device("xpu") is None: - register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen) + register_backend_for_device( + "xpu", + TritonScheduling, + WrapperCodeGen, + ) private_backend = torch._C._get_privateuse1_backend_name() if ( @@ -271,8 +311,8 @@ def get_device_op_overrides(device: str): @functools.lru_cache(None) def boolean_ops(): return ( - "is_inf", - "is_nan", + "isinf", + "isnan", "logical_not", "signbit", "le", @@ -344,6 +384,8 @@ def deduce_output_dtype_by_name( ): buf_name = args[1] return V.graph.get_dtype(buf_name) # type: ignore[arg-type] + elif op_name == "to_dtype_bitcast": + return kwargs["dtype"] if "dtype" in kwargs else args[-2] return None @@ -437,7 +479,7 @@ def propagate_loopbody(cls, body): @classmethod def propagate_scheduler_node(cls, node): - from ..ir import LoopBody + from ..loop_body import LoopBody from ..scheduler import SchedulerNode assert isinstance(node, SchedulerNode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 89b4666b67fb15..8bbe8023a312af 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3,10 +3,10 @@ import dataclasses import functools import itertools -import logging import math import re import sys +import warnings from copy import copy, deepcopy from enum import Enum from typing import cast, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -17,13 +17,12 @@ import torch.fx from torch._inductor import dependencies from torch._prims_common import is_float_dtype, is_integer_dtype -from torch.utils import _pytree as pytree from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from ..._dynamo.utils import counters from .. import codecache, config, cpp_builder, cpu_vec_isa, ir, metrics -from ..codegen.wrapper import WrapperCodeGen +from ..loop_body import LoopBody from ..scheduler import ( BaseSchedulerNode, BaseScheduling, @@ -62,6 +61,8 @@ OptimizationContext, ) from .cpp_utils import ( + _get_dtype_from_loopbodies, + _get_loop_body, cexpr, cexpr_index, codegen_rand, @@ -133,6 +134,26 @@ def get_export_declaration(): torch.float16, ] +VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float64, + torch.float, + torch.bfloat16, + torch.float16, + torch.bool, + torch.uint8, + torch.int8, + torch.int32, + torch.int64, +] + +MASKED_VECTORIZABLE_DTYPES: List[torch.dtype] = [ + torch.float, + torch.bfloat16, + torch.float16, + torch.uint8, + torch.int8, +] + def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: @@ -294,9 +315,12 @@ def visit_modular_indexing(divisor, modulus): @functools.lru_cache -def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int): - index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length) - return stride_at(index_vec_simplified, var) +def stride_at_vec_range( + index: sympy.Expr, var: sympy.Symbol, vec_length: Optional[int] = None +): + if vec_length: + index = simplify_index_in_vec_range(index, var, vec_length) + return stride_at(index, var) class OuterLoopFusedSchedulerNode(FusedSchedulerNode): @@ -1012,6 +1036,7 @@ def wrapper(*args, **kwargs): "index_expr", ]: setattr(self, name, wrap(method.__func__)) + return self @staticmethod @@ -1540,8 +1565,137 @@ def index_expr(expr, dtype): csevar.update_on_args("index_expr", (expr, dtype), {}) return csevar + @staticmethod + def frexp(x): + cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" + if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys): + return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys) + + cdtype = DTYPE_TO_CPP[x.dtype] + size = V.kernel.tail_size if V.kernel.tail_size else V.kernel.tiling_factor + code = BracesBuffer() + exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar() + exponent.update_on_args("frexp", (x,), kwargs={}) + mantissa.update_on_args("frexp", (x,), kwargs={}) + n_vec = V.kernel._get_num_vectors(x.dtype) + mantissa_t = ( + f"at::vec::Vectorized<{cdtype}>" + if n_vec == 1 + else f"at::vec::VectorizedN<{cdtype}, {n_vec}>" + ) + code.writeline( + f"at::vec::Vectorized {exponent};" + if n_vec == 1 + else f"at::vec::VectorizedN {exponent};" + ) + code.writeline(f"{mantissa_t} {mantissa};") + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf;" + ) + code.writeline(f"{x}.store(tmpbuf.data(), {cexpr_index(size)});") + code.writeline( + f"__at_align__ std::array tmpbuf_exponent;" + ) + code.writeline( + f"__at_align__ std::array<{cdtype}, {V.kernel.tiling_factor}> tmpbuf_mantissa;" + ) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline( + "tmpbuf_mantissa[i] = std::frexp(tmpbuf[i], &tmpbuf_exponent[i]);" + ) + code.writeline( + f"{exponent} = at::vec::Vectorized::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + if n_vec == 1 + else f"{exponent} = at::vec::VectorizedN::loadu(tmpbuf_exponent.data(), {cexpr_index(size)});" + ) + code.writeline( + f"{mantissa} = {mantissa_t}::loadu(tmpbuf_mantissa.data(), {cexpr_index(size)});" + ) + code.writeline("();") + V.kernel.compute.splice(code) + cse_vars = (mantissa, exponent) + for cache_key, cse_var in zip(cache_keys, cse_vars): + V.kernel.cse.cache[cache_key] = cse_var + return mantissa, exponent + + @classmethod + def scalarize(cls, scalar_func): + def inner(*args, **kwargs): + assert not kwargs + kernel = V.kernel + assert isinstance(kernel, CppVecKernel) + code = BracesBuffer() + code.writeline("[&]()") + vec_dtype = args[0].dtype + n_vec = kernel._get_num_vectors(vec_dtype) + size = kernel.tail_size if kernel.tail_size else kernel.tiling_factor + scalar_args = [] + cdtype = DTYPE_TO_CPP[vec_dtype] + output_mask = scalar_func.__name__ in ( + "isinf", + "isnan", + "signbit", + ) + octype = "bool" if output_mask else cdtype + octype = ( + DTYPE_TO_CPP[args[-2]] + if (scalar_func.__name__ == "to_dtype_bitcast") + else octype + ) + with code.indent(): + for argidx, arg in enumerate(args): + if isinstance(arg, CppCSEVariable): + assert arg.is_vec + assert arg.dtype == vec_dtype + code.writeline( + f"__at_align__ std::array<{cdtype}, {kernel.tiling_factor}> tmpbuf{argidx};" + ) + code.writeline( + f"{arg}.store(tmpbuf{argidx}.data(), {cexpr_index(size)});" + ) + scalar_args.append(f"tmpbuf{argidx}[i]") + else: + scalar_args.append(arg) + code.writeline( + f"__at_align__ std::array<{octype}, {kernel.tiling_factor}> tmpbuf_out;" + ) + res = scalar_func(*scalar_args) + code.writeline(f"for (int i = 0; i < {cexpr_index(size)}; i++)") + with code.indent(): + code.writeline(f"tmpbuf_out[i] = {res};") + if output_mask: + assert not kernel.tail_size + load_args = "tmpbuf_out.data()" + load_fn = f"at::vec::VecMask<{cdtype},{n_vec}>::from" + else: + load_args = f"tmpbuf_out.data(), {cexpr_index(size)}" + if n_vec == 1: + load_fn = f"at::vec::Vectorized<{octype}>::loadu" + else: + load_fn = f" at::vec::VectorizedN<{octype}, {n_vec}>::loadu" + code.writeline(f"return {load_fn}({load_args});") + code.writeline("()") + return code + + return inner + + @classmethod + def _initialize_scalarize(cls): + for name, method in vars(CppOverrides).items(): + if getattr(method, "__class__", None) == staticmethod and name not in vars( + CppVecOverrides + ): + func = cls.scalarize(method.__func__) + func.__name__ = name + setattr(cls, name, staticmethod(func)) + CppVecOverrides._initialize_pointwise_overrides("cppvec") +CppVecOverrides._initialize_scalarize() class CppTile2DOverrides(CppVecOverrides): @@ -1992,9 +2146,7 @@ def codegen_loops(self, code, worksharing): @property def assert_function(self) -> str: - if V.graph.aot_mode: - # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models - # compared with JIT Inductor which uses TORCH_CHECK + if config.abi_compatible: return "AOTI_TORCH_CHECK" else: return "TORCH_CHECK" @@ -2146,7 +2298,7 @@ def _get_vec_load_line( line = ( f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})" if load_mask_str - else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.num_elems})" + else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {cexpr_index(self.num_elems)})" ) return line @@ -2157,6 +2309,7 @@ def _load_or_store_non_contiguous( dtype: torch.dtype, buffer: Optional[IndentedBuffer] = None, store_value: Optional[Union[str, CppCSEVariable]] = None, + accu_store: bool = False, ) -> Optional[CppCSEVariable]: """ Load or store a vector in a non-contiguous way. The vector is initialized from an array that is @@ -2171,10 +2324,12 @@ def _load_or_store_non_contiguous( :param dtype: data type of `var` or `index` if `var` is None. :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`. :param store_value: the value to store. If None, we load the vector. + :param accu_store: whether accumulate the store_value to store_ptr. If True, a store_value should be provided :return: a CppCSEVariable that represents the loaded vector or None if it is a store. """ assert not store_value or var is not None, "store var must be provided" - + if accu_store: + assert store_value if buffer is None: buffer = self.loads @@ -2184,6 +2339,12 @@ def get_result_size(dtype: torch.dtype) -> int: else: return self.num_elems + def get_tiling_size(dtype: torch.dtype) -> int: + if dtype.itemsize < 4: + return self.tiling_factor * (4 // dtype.itemsize) + else: + return self.tiling_factor + def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: assert vec_var.is_vec code = BracesBuffer() @@ -2194,10 +2355,11 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: if vec_dtype == torch.bool: vec_dtype = torch.float result_size = get_result_size(vec_dtype) + tiling_size = get_tiling_size(vec_dtype) code.writeline( - f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;" + f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {tiling_size}> tmpbuf;" ) - line = f"{vec_var}.store(tmpbuf.data(), {result_size});" + line = f"{vec_var}.store(tmpbuf.data(), {cexpr_index(result_size)});" code.writeline(line) code.writeline("return tmpbuf;") code.writeline("()") @@ -2209,12 +2371,15 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: code.writeline("[&]") with code.indent(): result_size = get_result_size(dtype) + tiling_size = get_tiling_size(dtype) result_declare = ( - f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;" + f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {tiling_size}> tmpbuf;" ) code.writeline(result_declare) if store_value: - code.writeline(f"{store_value}.store(tmpbuf.data(), {result_size});") + code.writeline( + f"{store_value}.store(tmpbuf.data(), {cexpr_index(result_size)});" + ) itervar_inner = sympy_index_symbol( f"{self.itervars[self.tiling_idx]}_inner" ) @@ -2240,12 +2405,12 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: else: load_mask = f"{self._load_mask} != 0" if cpp_builder.is_gcc(): - code.writeline(f"#pragma GCC unroll {self.num_elems}") + code.writeline(f"#pragma GCC unroll {self.tiling_factor}") else: - code.writeline(f"#pragma unroll {self.num_elems}") + code.writeline(f"#pragma unroll {self.tiling_factor}") code.writeline( f"for (long {itervar_inner} = 0; " - + f"{itervar_inner} < {self.num_elems}; " + + f"{itervar_inner} < {cexpr_index(self.num_elems)}; " + f"{itervar_inner}++)" ) with code.indent(), contextlib.ExitStack() as stack: @@ -2261,7 +2426,8 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable: code.writeline(f"if ({load_mask})") stack.enter_context(code.indent()) if store_value: - code.writeline(f"{rhs} = tmpbuf[{itervar_inner}];") + conjunction = "+=" if accu_store else "=" + code.writeline(f"{rhs} {conjunction} tmpbuf[{itervar_inner}];") else: code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};") if not store_value: @@ -2304,6 +2470,7 @@ def _get_store_line( var: str, index: sympy.Expr, dtype: torch.dtype, + accu_store: bool = False, ): """ Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles @@ -2325,24 +2492,47 @@ def _get_store_line( if dtype == torch.float and self.tail_size is None: code.writeline(f"{value}.store({var_expr});") else: - code.writeline(f"{value}.store({var_expr}, {self.num_elems});") + code.writeline( + f"{value}.store({var_expr}, {cexpr_index(self.num_elems)});" + ) else: self._load_or_store_non_contiguous( - var, index, dtype, buffer=code, store_value=value + var, index, dtype, buffer=code, store_value=value, accu_store=accu_store ) return code def store(self, name, index, value, mode=None): assert "buf" in name - assert mode is None assert isinstance(value, CppCSEVariable), value if not value.is_vec: # this happens when we store a scalar into a vectorized buffer like "fill" value = self.broadcast(value) var = self.args.output(name) index = self.rename_indexing(index) - code = self._get_store_line(value, var, index, V.graph.get_dtype(name)) - self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + dtype = V.graph.get_dtype(name) + if mode is None: + code = self._get_store_line(value, var, index, dtype) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + elif mode == "atomic_add": + if not config.cpp.dynamic_threads and self.num_threads == 1: + code = self._get_store_line( + f"{value}", + var, + index, + dtype, + accu_store=True, + ) + self.stores.splice(code.map(lambda x: DeferredLine(name, x))) + else: + n_src = self._get_num_vectors(dtype) + n_idx = self._get_num_vectors(torch.int64) + cdtype = DTYPE_TO_CPP[dtype] + index = ops.index_expr(index, torch.int64).value + assert index.is_vec + line = f"atomic_add_vec<{cdtype}, {n_idx}, {n_src}>({var}, {index}, {value});" + self.stores.writeline(DeferredLine(name, line)) + else: + raise NotImplementedError(f"store mode={mode}") def reduction(self, dtype, src_dtype, reduction_type, value): assert reduction_type in VECTORIZABLE_RTYPES @@ -2488,9 +2678,9 @@ def reduction(self, dtype, src_dtype, reduction_type, value): ) is_bool = dtype == torch.bool # we are using at::vec::VecMask for bool - vec_dtype = "float" if is_bool else DTYPE_TO_CPP[dtype] - vec = f"at::vec::Vectorized<{vec_dtype}>" - vec_reduce_all_func = f"at::vec::vec_reduce_all<{vec_dtype}>" + vec_dtype = torch.float if is_bool else dtype + vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[vec_dtype]}>" + vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[vec_dtype]}, {self._get_num_vectors(vec_dtype)}>" next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})" self.reduction_suffix.writeline( @@ -2518,6 +2708,8 @@ def store_reduction(self, name, index, value): if out_dtype.is_floating_point else torch.int64 ) + out_num_vectors = V.kernel._get_num_vectors(out_dtype) + src_num_vectors = V.kernel._get_num_vectors(dtype) code = IndentedBuffer() if self.tiling_idx >= self.reduction_depth: # Horizontal reduction @@ -2528,9 +2720,19 @@ def store_reduction(self, name, index, value): # Vertical reduction if out_dtype != dtype: converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" - code.writeline( - f"auto {converted_value} = at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value});" - ) + if out_dtype == torch.bool: + convert = f"{value}.template cast()" + else: + if src_num_vectors == out_num_vectors == 1: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value})" + ) + else: + convert = ( + f"at::vec::convert<{DTYPE_TO_CPP[out_dtype]}," + f"{out_num_vectors},{DTYPE_TO_CPP[dtype]},{src_num_vectors}>({value})" + ) + code.writeline(f"auto {converted_value} = {convert};") value = converted_value code.splice(self._get_store_line(value, var, index, out_dtype)) self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x))) @@ -2640,7 +2842,7 @@ def reduction_combine_vec( is_bool = src_dtype == torch.bool if reduction_type == "max": if self.tail_size: - return f"max_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"max_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return ( f"{var} | {next_value}" @@ -2649,7 +2851,7 @@ def reduction_combine_vec( ) elif reduction_type == "min": if self.tail_size: - return f"min_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"min_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return ( f"{var} & {next_value}" @@ -2658,29 +2860,29 @@ def reduction_combine_vec( ) elif reduction_type == "sum": if self.tail_size: - return f"sum_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: conjunction = "|" if is_bool else "+" return f"{var} {conjunction} {next_value}" elif reduction_type == "prod": if self.tail_size: - return f"prod_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"prod_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"{var} * {next_value}" elif reduction_type == "xor_sum": if self.tail_size: - return f"xor_sum_masked_reduce({var}, {next_value}, {self.tail_size})" + return f"xor_sum_masked_reduce({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"{var} ^ {next_value}" elif reduction_type == "welford_reduce": if use_weight_recps: if self.tail_size: - return f"welford_combine({var}, {next_value}, {self.tail_size}, &{self.weight_recps_val})" + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)}, &{self.weight_recps_val})" else: return f"welford_combine({var}, {next_value}, &{self.weight_recps_val})" else: if self.tail_size: - return f"welford_combine({var}, {next_value}, {self.tail_size})" + return f"welford_combine({var}, {next_value}, {cexpr_index(self.tail_size)})" else: return f"welford_combine({var}, {next_value})" elif reduction_type == "welford_combine": @@ -2691,7 +2893,7 @@ def reduction_combine_vec( # When combining intermediate accumulators we have a Welford struct mean, m2, weight = reduction_project(reduction_type, next_value) if self.tail_size: - return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {self.tail_size})" + return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {cexpr_index(self.tail_size)})" else: return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})" elif reduction_type in ("argmin", "argmax"): @@ -2708,7 +2910,7 @@ def reduction_combine_vec( if self.tail_size: return ( f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" - f"({var}, {next_value}{arg_extra}, {self.tail_size})" + f"({var}, {next_value}{arg_extra}, {cexpr_index(self.tail_size)})" ) else: return f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>({var}, {next_value}{arg_extra})" @@ -2746,6 +2948,11 @@ def indirect_assert(self, var, lower, upper, mask=None): mask = f"{self._get_mask_type(var.dtype)}({mask})" # We need not check when the mask is False cond = f"({cond}) | ~({mask})" + if self.tail_size: + cond = ( + f"{self._get_mask_type(var.dtype)}::set({self._get_mask_type(var.dtype)}::from(1)" + f", ({cond}), {cexpr_index(self.tail_size)})" + ) cond = f"({cond}).all_masked()" return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")' @@ -2849,20 +3056,29 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): src = f"{var} + {cexpr_index(index)}" dst = "__place_holder__" ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}" - ld_dst = f"{self.num_elems}" + ld_dst = f"{cexpr_index(self.num_elems)}" if is_store: src, dst = dst, src ld_src, ld_dst = ld_dst, ld_src need_define = True if self.inner_is_tiling_idx ^ is_store: + M, N = self.inner_num_elems, self.outer_num_elems + else: + M, N = ( + self.outer_num_elems, + self.inner_num_elems, + ) + if (isinstance(M, sympy.Expr) and not M.is_number) or ( + isinstance(N, sympy.Expr) and not N.is_number + ): load_or_store = ( - f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{self.inner_num_elems},{self.outer_num_elems}>" - f"({src}, {ld_src}, {dst}, {ld_dst});" + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]}>" + f"({src}, {ld_src}, {dst}, {ld_dst}, {cexpr_index(M)}, {cexpr_index(N)});" ) else: load_or_store = ( - f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{self.outer_num_elems},{self.inner_num_elems}>" + f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{cexpr_index(M)},{cexpr_index(N)}>" f"({src}, {ld_src}, {dst}, {ld_dst});" ) if is_store: @@ -2874,7 +3090,7 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): tile_var = self.cse.cache[load_or_store] if need_define: - define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{self.outer_num_elems}*{self.inner_num_elems}];" + define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" self.preloads.writeline(define_line) load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) @@ -2924,7 +3140,7 @@ def store(self, name, index, value, mode=None): torch.uint8, torch.int8, ]: - line = f"{value}.store({storebuf}, {self.num_elems});" + line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});" else: line = f"{value}.store({storebuf});" self.stores.writeline(DeferredLine(name, line)) @@ -2936,11 +3152,11 @@ def codegen_inner_loops(self, code): inner = self.inner_itervar() if self.inner_is_tiling_idx: code.writeline( - f"for (long {inner} = 0; {inner} < {self.outer_num_elems}; {inner}++)" + f"for (long {inner} = 0; {inner} < {cexpr_index(self.outer_num_elems)}; {inner}++)" ) else: code.writeline( - f"for (long {inner} = 0; {inner} < {self.inner_num_elems}; {inner}++)" + f"for (long {inner} = 0; {inner} < {cexpr_index(self.inner_num_elems)}; {inner}++)" ) def set_ranges(self, group, reduction_group): @@ -2969,248 +3185,16 @@ def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: ) -class CppVecKernelChecker(CppVecKernel): - def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1): - super().__init__(args, num_threads, tiling_factor, tiling_idx) - - # Since this kernel is only for checker but does not generate any - # code, so we need to decrease the kernel count. - metrics.generated_kernel_count -= 1 - - # Used to record the graph wrapper code as the wrapper_code status could be - # changed during graph run. - self._orig_wrapper_code = None - - self.simd_vec = True - - self.simd_masked_vec = True - - self.fast_vec_list = [] - for k, v in CppVecOverrides.__dict__.items(): - if isinstance(v, staticmethod): - self.fast_vec_list.append(k) - self.exit_stack = contextlib.ExitStack() - - # Cache all the load result - self.supported_dtypes: List[torch.dtype] = [ - torch.float64, - torch.float, - torch.bfloat16, - torch.float16, - torch.bool, - torch.uint8, - torch.int8, - torch.int32, - torch.int64, - ] - - # TODO: remove it after all data types support masked vectorization. - self.supported_dtypes_for_masked_vec: List[torch.dtype] = [ - torch.float, - torch.bfloat16, - torch.float16, - torch.uint8, - torch.int8, - ] - - def disable_vec(self, msg=None): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Disabled vectorization: %s", msg) - self.simd_vec = False - self.simd_masked_vec = False - - def disable_masked_vec(self, msg=None): - if schedule_log.isEnabledFor(logging.DEBUG): - schedule_log.debug("Disabled masked vectorization: %s", msg) - self.simd_masked_vec = False - - def load(self, name: str, index: sympy.Expr): - with RecordOptimizationContext(__name__) as node_ctx: - load_dtype = V.graph.get_dtype(name) - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - - opt_ctx.dtype = load_dtype - var = self.cse.newvar() - - if load_dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{load_dtype} not supported by masked vectorization" - ) - - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked load") - - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return var - - if load_dtype not in self.supported_dtypes and ( - index.has(self.itervars[self.tiling_idx]) - or free_symbol_is_type(index, SymT.TMP) - ): - self.disable_vec(f"{load_dtype} not supported by load") - return var - - return var - - def store(self, name, index, value, mode=None): - with RecordOptimizationContext(__name__) as node_ctx: - store_dtype = V.graph.get_dtype(name) - - if store_dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{store_dtype} not supported by masked vectorization" - ) - - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked store") - - if len(self.itervars) == 0: - self.disable_vec("not a loop") - return self.simd_vec - - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - opt_ctx.dtype = store_dtype - - if store_dtype not in self.supported_dtypes: - self.disable_vec(f"{store_dtype} not supported by store") - return self.simd_vec - - assert "buf" in name - index = self.rename_indexing(index) - - if mode: - self.disable_vec(f"store mode: {mode}") - return self.simd_vec - - return self.simd_vec - - def reduction(self, dtype, src_dtype, reduction_type, value): - if has_free_symbols(self.ranges): - self.disable_masked_vec("Symbolic ranges not supported by masked reduction") - - if reduction_type not in VECTORIZABLE_RTYPES: - self.disable_vec( - f"reduction: dtype {dtype}, src_dtype {src_dtype}, reduction_type {reduction_type}" - ) - if is_welford_reduction(reduction_type): - return tuple([self.simd_vec] * 3) - return self.simd_vec - - def check_bounds( - self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ): - return self.simd_vec - - def store_reduction(self, name, index, value): - return self.simd_vec - - def __exit__(self, exc_type, exc_val, exc_tb): - # Restore the wrapper_code - V.graph.wrapper_code = self._orig_wrapper_code # type: ignore[assignment] - self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - - def __enter__(self): - # Record the graph wrapper code. The wrapper_code status could be - # changed during graph run. Regarding this checker, we also need to - # run the graph but we don't expect to change any status that would - # impact the code generation. Hence, we record the graph wrapper code - # and replace it with a dummy wrapper_code and then restore to the - # original one as long as the checker is finished. - self._orig_wrapper_code = V.graph.wrapper_code - V.graph.wrapper_code = WrapperCodeGen() - - parent_handler = V.MockHandler() - - class VecCheckerProxy: - @staticmethod - def __getattr__(name): # type: ignore[misc] - def inner(*args, **kwargs): - if name not in self.fast_vec_list: - self.disable_vec(f"op: {name}") - - parent_val = getattr(parent_handler, name)(*args, **kwargs) - return pytree.tree_map(lambda _: self.simd_vec, parent_val) - - return inner - - @staticmethod - def load(name: str, index: sympy.Expr): - return self.load(name, index) - - @staticmethod - def store(name, index, value, mode=None): - return self.store(name, index, value, mode=mode) - - @staticmethod - def reduction(dtype, src_dtype, reduction_type, value): - return self.reduction(dtype, src_dtype, reduction_type, value) - - @staticmethod - def store_reduction(name, index, value): - return self.store_reduction(name, index, value) - - @staticmethod - def check_bounds( - expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ): - return self.check_bounds(expr, size, lower, upper) - - @staticmethod - def constant(val, dtype): - with RecordOptimizationContext(__name__) as node_ctx: - opt_ctx: OptimizationContext = node_ctx.get_opt_ctx() - assert opt_ctx - - if opt_ctx.dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{opt_ctx.dtype} not supported by masked vectorization" - ) - - if opt_ctx.dtype not in self.supported_dtypes: - self.disable_vec(f"constant dtype: {opt_ctx.dtype}") - return val - - @staticmethod - def index_expr(expr, dtype): - return self.cse.newvar() - - @staticmethod - def indirect_indexing(index_var, size, check=True, wrap_neg=True): - return sympy_index_symbol(str(index_var)) - - @staticmethod - def masked(mask, body, other): - body() - return self.cse.newvar() - - @staticmethod - def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): - if dtype not in self.supported_dtypes_for_masked_vec: - self.disable_masked_vec( - f"{dtype} not supported by masked vectorization" - ) - - if dtype not in self.supported_dtypes: - self.disable_vec(f"to_dtype: {dtype}") - return x - - self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) - return self - - -def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Optional[torch.dtype]: +def get_loop_body_lowp_fp(_body: LoopBody) -> Tuple[Optional[torch.dtype], bool]: """ - Returns the low precision float data type (torch.float16/torch.bfloat16) if all the - nodes can codegen with this data type. Otherwise returns None. + Returns the low precision data type (torch.float16/torch.bfloat16) contained in the nodes + and if all the nodes can codegen with this data type without converting to float. + Otherwise returns None and True. """ sub_blocks = [_body.root_block] + list(_body.subblocks.values()) _lowp_fp_type: Optional[torch.dtype] = None - + _use_fp32 = False for sub_block in sub_blocks: for _node in sub_block.graph.nodes: if _node.op == "placeholder" or _node.target in ( @@ -3227,23 +3211,22 @@ def get_loop_body_lowp_fp(_body: ir.LoopBody) -> Optional[torch.dtype]: "neg", "output", ]: - return None + _use_fp32 = True if hasattr(_node, "meta") and _node.meta: assert OptimizationContext.key in _node.meta opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key] if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP: - return None - if _lowp_fp_type: - assert ( - _lowp_fp_type == opt_ctx.dtype - ), "do not support bf16/fp16 mix" + _use_fp32 = True + elif _lowp_fp_type is not None: + if _lowp_fp_type != opt_ctx.dtype: + warnings.warn("bf16 and fp16 are mixed in the scheduler node.") else: _lowp_fp_type = opt_ctx.dtype else: - return None + _use_fp32 = True - return _lowp_fp_type + return _lowp_fp_type, _use_fp32 class TilingSelect: @@ -3261,28 +3244,15 @@ def select_tiling( var_sizes_list, ) -> Tuple[List[int], List[int]]: # TODO(jgong5): support alternative tiling factors and data types - loop_bodies = None - if all(isinstance(fn, ir.LoopBody) for fn in fn_list): - loop_bodies = fn_list - else: - if hasattr(fn_list[0], "original_fn"): - # For the case of local buffer, we wrap the fn with localize_function - assert all(hasattr(fn, "original_fn") for fn in fn_list) - assert all( - isinstance(fn.original_fn.args[0]._body, ir.LoopBody) - for fn in fn_list - ) - loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] - else: - assert all(isinstance(fn, functools.partial) for fn in fn_list) - assert all(isinstance(fn.args[0]._body, ir.LoopBody) for fn in fn_list) - loop_bodies = [fn.args[0]._body for fn in fn_list] - assert loop_bodies is not None - + loop_bodies = _get_loop_body(fn_list) + all_dtypes = _get_dtype_from_loopbodies(loop_bodies) + assert all_dtypes + if any(dtype not in VECTORIZABLE_DTYPES for dtype in all_dtypes): + return [], [] dtype = torch.float - _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0]) + _lowp_fp_dtype = get_loop_body_lowp_fp(loop_bodies[0])[0] if _lowp_fp_dtype and all( - (get_loop_body_lowp_fp(loop_body) == _lowp_fp_dtype) + (get_loop_body_lowp_fp(loop_body)[0] == _lowp_fp_dtype) for loop_body in loop_bodies[1:] ): dtype = _lowp_fp_dtype @@ -3291,7 +3261,13 @@ def select_tiling( tiling_indices = self._select_tiling_indices( fn_list, var_sizes_list, tiling_factor ) + if tiling_indices: + group, reduction_group = max( + var_sizes_list, key=lambda sizes: len(sizes[1]) + ) + call_ranges = tuple(group) + tuple(reduction_group) + if config.cpp.enable_tiling_heuristics: def _try_get_stride( @@ -3327,10 +3303,6 @@ def _is_valid_indices( < len(itervars) ) - group, reduction_group = max( - var_sizes_list, key=lambda sizes: len(sizes[1]) - ) - call_ranges = tuple(group) + tuple(reduction_group) itervars = [ sympy_index_symbol_with_prefix(SymT.XBLOCK, n) for n in range(len(call_ranges)) @@ -3350,13 +3322,10 @@ def _is_valid_indices( for _node in sub_block.graph.nodes: if _node.target in ["index_expr", "load", "store"]: # get the index and replace prefix from z to x + arg_idx = 1 if _node.target == "index_expr" else 2 index = sub_block.body.indexing_from_args( (vars, reduction_vars) - )[ - _node.args[ - 1 if _node.target == "index_expr" else 2 - ].args[0] - ] + )[_node.args[arg_idx].args[0]] if _is_valid_indices(itervars, tiling_indices): stride = _try_get_stride( index, itervars, tiling_factor, tiling_indices @@ -3410,6 +3379,27 @@ def _is_valid_indices( # when needed. return [], [] + if dtype in DTYPE_LOWP_FP: + # For lower precision data type, if the call_range is not long enough, + # use tiling_factor // 2 for better performance + factor_lowp = cpu_vec_isa.pick_vec_isa().nelements(dtype=dtype) + for tiling_indice in tiling_indices: + if tiling_indice < 0: + tiling_indice = tiling_indice + len(call_ranges) + if tiling_indice < 0 or tiling_indice >= len(call_ranges): + continue + if has_free_symbols(call_ranges): + call_range = V.graph.sizevars.size_hint( + call_ranges[tiling_indice], fallback=0 + ) + if call_range < factor_lowp: + V.graph.sizevars.guard_lt(call_range, factor_lowp) + tiling_factor = factor_lowp // 2 + break + elif call_ranges[tiling_indice] < factor_lowp: + tiling_factor = factor_lowp // 2 + break + if len(tiling_indices) == 1: return [tiling_factor], tiling_indices if len(tiling_indices) == 2: @@ -3480,13 +3470,16 @@ def data_type_propagation(self, nodes): # Check if all the nodes of a given fx graph can support BF16/FP16 def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode): - if not isinstance(scheduler_node._body, ir.LoopBody): + if not isinstance(scheduler_node._body, LoopBody): return True # Propagate the dtype to check if all the fx node is bf16/fp16 DataTypePropagation.propagate_scheduler_node(scheduler_node) - return get_loop_body_lowp_fp(scheduler_node._body) is not None + return ( + get_loop_body_lowp_fp(scheduler_node._body)[0] is not None + and not get_loop_body_lowp_fp(scheduler_node._body)[1] + ) - def legalize_lowp_fp_dtype_loopbody(self, loop_body: ir.LoopBody): + def legalize_lowp_fp_dtype_loopbody(self, loop_body: LoopBody): def add_to_dtype(sub_graph: torch.fx.Graph): def is_lowp_fp_load(node: torch.fx.Node): if node.target not in ["load"]: @@ -3653,18 +3646,9 @@ def legalize_lowp_fp_dtype(self, nodes): for _node in nodes: assert isinstance(_node, SchedulerNode) - assert isinstance(_node._body, ir.LoopBody) - node: SchedulerNode = _node - - def is_memory_copy_scheduler_node(node: SchedulerNode): - op_counts = node.read_writes.op_counts - return ( - len(op_counts) == 2 and "load" in op_counts and "store" in op_counts - ) - - should_legalize = not is_memory_copy_scheduler_node(node) - if should_legalize: - body: ir.LoopBody = node._body + assert isinstance(_node._body, LoopBody) + body: LoopBody = _node._body + if not body.is_memory_copy(): self.legalize_lowp_fp_dtype_loopbody(body) def codegen_functions(self, fn_list, var_sizes_list): @@ -3711,6 +3695,10 @@ def run(kernel): if not self.picked_vec_isa: return + if not self.itervars: + # not a loop + return + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. # But the generated scalar kernel has updated these global contexts. Hence, the other kernels # should not do this again to avoid context conflict. By now, we only control the @@ -3722,29 +3710,11 @@ def run(kernel): ) assert len(tiling_factors) == len(tiling_indices) # This should be removed after full support for vectorization is implemented. - could_vec = True could_masked_vec = True - if tiling_factors and tiling_indices: - tiling_factor = tiling_factors[0] - assert all( - _tiling_factor == tiling_factor for _tiling_factor in tiling_factors - ) - for tiling_indice in tiling_indices: - with CppVecKernelChecker( - deepcopy(self.kernel_group.args), - parallel_num_threads(), - tiling_factor, - tiling_indice, - ) as vec_checker: - run(vec_checker) - could_vec = could_vec and vec_checker.simd_vec - could_masked_vec = ( - could_masked_vec and vec_checker.simd_masked_vec - ) - if not could_vec: - tiling_factors = [] - tiling_indices = [] - break + all_dtypes = _get_dtype_from_loopbodies(_get_loop_body(fn_list)) + if any(dtype not in MASKED_VECTORIZABLE_DTYPES for dtype in all_dtypes): + # can be removed after masked vectorizable dtype are same with vectorizable dtype + could_masked_vec = False if len(tiling_indices) == 1: vec_kernel = codegen_kernel( @@ -3756,7 +3726,7 @@ def run(kernel): ) main_loop.set_kernel(vec_kernel) main_loop.simd_vec = True - if could_masked_vec and (tail_loop.size - tail_loop.offset) >= 4: + if config.cpp.enable_loop_tail_vec and could_masked_vec: tail_loop.steps = tail_loop.size - tail_loop.offset masked_vec_kernel = codegen_kernel( CppVecKernel, @@ -3796,7 +3766,7 @@ def run(kernel): ) inner_main_loop.set_kernel(tile2d_kernel) - if could_masked_vec: + if config.cpp.enable_loop_tail_vec and could_masked_vec: ( inner_main_loop_of_outer_tail_loop, inner_tail_loop_of_outer_tail_loop, @@ -4183,6 +4153,113 @@ def can_fuse_vertical(self, node1, node2): self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) + def try_loop_split(self, nodes: List[SchedulerNode]): + """ + Apply loop split optimization. + When one of the indexing_exprs contains a division, we eliminate the division by splitting the loop + to avoid non-contiguous loads, subject to the following conditions: + 1. No reduction and no mudular index for all nodes. + 2. Only one node's one indexing_exprs contains a division, according to this indexing_exprs, + we can get the dimension that needs to be split, and the split dimension is contiguous + in all other indexing_exprs. + + For example, if the node's var_ranges: {z0: 2, z1: 9216, z2: 960} and indexing_exprs: + {'index0': 8847360*z0 + 960*z1 + z2, 'index1': 32*z0 + (z2//30), 'index2': z2}, + we will split z2 -> 30*z2 + z3, then the node's var_ranges will be changed to + {z0: 2, z1: 9216, z2: 32, z3: 30} and indexing_exprs will be changed to + {'index0': 8847360*z0 + 960*z1 + 30*z2 + z3, 'index1': 32*z0 + z2, 'index2': 30*z2 + z3}. + """ + + # No reduction and no mudular + if any( + len(node.group[1][1]) != 0 + or any( + expr.has(ModularIndexing) for expr in node._body.indexing_exprs.values() + ) + for node in nodes + ): + return nodes + + split_var = None + split_number = None + divide_index_name = None + num_div = 0 + match_div = False + matched_node = None + + for node in nodes: + assert isinstance(node.node, ir.ComputedBuffer) + _, original_body, _ = node.node.get_default_sizes_body() + for name, expr in original_body.indexing_exprs.items(): + num_div += expr.count(FloorDiv) + if num_div > 1: + return nodes + if expr.count(FloorDiv) == 1: + div_expr = expr.find(FloorDiv).pop() + split_var = div_expr.args[0] + split_number = div_expr.args[1] + divide_index_name = name + if ( + isinstance(split_number, sympy.core.numbers.Integer) + and isinstance(split_var, sympy.core.symbol.Symbol) + and split_var in original_body.iter_vars + and divide_index_name is not None + and all( + stride_at_vec_range(expr, split_var) == 1 + for name, expr in original_body.indexing_exprs.items() + if name != divide_index_name + ) + ): + match_div = True + matched_node = node + + # Only one node contains a division, and the split dimension is contiguous in all other indexing_exprs. + if not match_div: + return nodes + + extra_indexing_constraints = None + + def loop_split(sizes, body, vars): + index_size, reduce_size = sizes + index_vars, reduce_vars = vars + split_idx = index_vars.index(split_var) + new_index_size = index_size.copy() + new_index_size[split_idx] = index_size[split_idx] // split_number + new_index_size.insert(split_idx + 1, split_number) + (new_index_vars, _), var_ranges = dependencies.index_vars_no_squeeze( + new_index_size, reduce_size, prefix="y" + ) + iter_vars = new_index_vars.copy() + divisor_var = iter_vars.pop(split_idx + 1) + iter_vars[split_idx] = split_number * iter_vars[split_idx] + divisor_var + body = ir.LoopBody( + body, [iter_vars, reduce_vars], var_ranges, new_index_vars, reduce_vars + ) + nonlocal extra_indexing_constraints + if not extra_indexing_constraints: + extra_indexing_constraints = ( + body.var_ranges, + list(body.indexing_exprs.values()), + ) + return ( + (new_index_size, reduce_size), + body, + (new_index_vars, reduce_vars), + ) + + # Here decide the final loop order + for node in nodes: + if node == matched_node: + node.recompute_size_and_body(recompute_sizes_body_func=loop_split) + for node in nodes: + if node != matched_node: + node.recompute_size_and_body( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=loop_split, + ) + + return nodes + def codegen_outer_loop_node( self, node: OuterLoopFusedSchedulerNode, @@ -4258,9 +4335,9 @@ def is_all_write_read_contiguous(): ): contiguous_index_expr += stride * var stride *= range - write_index_expr = scheduler_node._body.writes_name2expr[ + write_index_expr = scheduler_node._body.get_write_expr( scheduler_buffer.get_name() - ] + ) def is_contiguous_index(x): return x == contiguous_index_expr @@ -4268,9 +4345,9 @@ def is_contiguous_index(x): return is_contiguous_index(write_index_expr) and all( isinstance(user.node, SchedulerNode) and is_contiguous_index( - user.node._body.reads_name2expr[ + user.node._body.get_read_expr( scheduler_buffer.get_name() - ], + ), ) for user in scheduler_buffer.users ) @@ -4385,6 +4462,7 @@ def codegen_node( self.codegen_outer_loop_node(node) else: nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + nodes = self.try_loop_split(nodes) cpp_kernel_proxy = CppKernelProxy(kernel_group) cpp_kernel_proxy.codegen_nodes(nodes) kernel_group.finalize_kernel(cpp_kernel_proxy, nodes) @@ -4484,7 +4562,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): compile_wrapper.splice(src_code, strip=True) if not V.graph.cpp_wrapper: compile_wrapper.writeline("''')") - wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), gpu=False) return kernel_name def flush(self): @@ -4565,7 +4643,7 @@ def codegen_group(self, name=None) -> str: def call_kernel(self, wrapper, kernel_name): _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call( - kernel_name, call_args, cuda=False, arg_types=arg_types + kernel_name, call_args, gpu=False, arg_types=arg_types ) @@ -4767,7 +4845,15 @@ def lines(self): line1 = "" offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}" size_str = f"{self.var}<{size_expr}" - steps_str = f"{self.var}+={cexpr_index(self.steps)}" + if self.steps.is_number: + steps_str = f"{self.var}+={cexpr_index(self.steps)}" + else: + # If the step size is 0, change it to 1 because a step size of 0 + # will cause floating point exception (core dump) during parallelization. + steps_str = ( + f"{self.var}+=({cexpr_index(self.steps)} == 0 ? " + f"1 : {cexpr_index(self.steps)})" + ) line2 = f"for({offset_str}; {size_str}; {steps_str})" if self.collapsed or not line1: return [line2] diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 6558e5ab5c8793..9756af21753078 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -58,18 +58,37 @@ const int64_t Mr_blocks = (M + Mr - 1) / Mr; {%- if num_threads > 1 %} int64_t Mt_blocks, Nt_blocks, Kt_blocks; - mm_get_thread_blocking(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); + mm_get_thread_blocking(num_threads, {{config.cpp.gemm_max_k_slices}}, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks); {%- else %} const auto Mt_blocks = Mr_blocks; const auto Nt_blocks = Nr_blocks; const auto Kt_blocks = Kr_blocks; {%- endif %} - const int64_t Mc_blocks = Mt_blocks; - const int64_t Nc_blocks = 1; - const int64_t Kc_blocks = Kt_blocks; + int64_t Mc_blocks, Nc_blocks, Kc_blocks; + uint32_t L1_cache_size = {{L1_cache_size}}; + uint32_t L2_cache_size = {{L2_cache_size}}; + mm_get_cache_blocking<{{kernel.dtype(X)}}, {{kernel.dtype(W)}}>( + num_threads, + M, + N, + K, + Mr, + Nr, + Kr, + Mt_blocks, + Nt_blocks, + Kt_blocks, + Mc_blocks, + Nc_blocks, + Kc_blocks, + L1_cache_size, + L2_cache_size + ); const int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; - const int64_t num_Nc_blocks = Nr_blocks; - const int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; + const int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; + const int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + const int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + const int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; {%- else %} constexpr int64_t M = {{kernel.size(GemmOut, 0)}}; constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; @@ -81,7 +100,9 @@ constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; - constexpr int64_t num_k_slices = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; + constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; + constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; + constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; {%- endif %} // make sure all partitions are assigned @@ -92,8 +113,8 @@ {%- if maybe_k_slicing %} std::unique_ptr[]> local_buf_ptrs; - if (num_k_slices > 1) { - local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_k_slices]); + if (num_Kt_blocks > 1) { + local_buf_ptrs.reset(new std::unique_ptr<{{DTYPE_TO_CPP[acc_buf_dtype]}}[]>[num_Mc_blocks * num_Nc_blocks * num_Kt_blocks]); } {%- endif %} @@ -101,33 +122,47 @@ #pragma omp parallel num_threads({{num_threads}}) { const int tid = omp_get_thread_num(); - int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; - mm_get_thread_blocks( - tid, Mr_blocks, Nr_blocks, Kr_blocks, Mt_blocks, Nt_blocks, Kt_blocks, - m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); - {%- if maybe_k_slicing %} - const int64_t k_group_id = tid / num_k_slices; - const int64_t k_slice_id = tid % num_k_slices; - {%- endif %} + const int64_t k_group_id = tid / num_Kt_blocks; + const int64_t k_slice_id = tid % num_Kt_blocks; + const int64_t n_group_id = k_group_id / num_Nt_blocks; + const int64_t n_slice_id = k_group_id % num_Nt_blocks; + const int64_t k_block_start = k_slice_id * Kt_blocks; + const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); + const int64_t n_block_start = n_slice_id * Nt_blocks; + const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); + const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); + const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); + const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; {%- else %} { - const int tid = 0; - const int64_t m_block_start = 0; - const int64_t m_block_end = Mr_blocks; - const int64_t n_block_start = 0; - const int64_t n_block_end = Nr_blocks; - const int64_t k_block_start = 0; - const int64_t k_block_end = Kr_blocks; + constexpr int tid = 0; + constexpr int64_t k_group_id = 0; + constexpr int64_t k_slice_id = 0; + constexpr int64_t n_group_id = 0; + constexpr int64_t n_slice_id = 0; + constexpr int64_t m_block_start = 0; + constexpr int64_t m_block_end = Mr_blocks; + constexpr int64_t n_block_start = 0; + constexpr int64_t n_block_end = Nr_blocks; + constexpr int64_t k_block_start = 0; + constexpr int64_t k_block_end = Kr_blocks; + {%- if is_dynamic_M %} + const int64_t num_Mc_blocks_per_thread = num_Mc_blocks; + {%- else %} + constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; + {%- endif %} {%- endif %} {{ micro_gemm.codegen_init(kernel) }} - for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { - const int64_t m_start = mc * Mr; - const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); - const int64_t m_size = m_end - m_start; {%- if use_local_acc %} {%- set acc_buf_name = "local_acc_buf" %} - {{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "Nc_blocks*Nr"], acc_buf_dtype) }} + {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }} {%- endif %} + for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { + const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; + const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; + const int64_t m_start = mc * Mr; + const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); + const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); @@ -145,7 +180,7 @@ int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} for (int64_t nci = nc; nci < nc_block_end; nci++) { -{%- set acc_slice = kernel.slice_nd(acc, [(), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} +{%- set acc_slice = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")]) %} {%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %} {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} if (kc == k_block_start) { @@ -156,14 +191,15 @@ } } {%- if maybe_k_slicing %} - if (num_k_slices > 1) { + if (num_Kt_blocks > 1) { const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; - local_buf_ptrs[mxn_cache_block_id * num_k_slices + k_slice_id].reset({{ kernel.release_buffer(acc_buf_name) }}); + local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + k_slice_id].reset( + {{ kernel.release_buffer(acc_buf_name) }}); } else {%- endif %} { {%- set tile_Y = kernel.slice_nd(Y_2d, [("m_start", "m_end"), ("n_start", "n_end")]) %} -{%- set tile_acc = kernel.slice_nd(acc, [(), ("0", "n_end - n_start")]) %} +{%- set tile_acc = kernel.slice_nd(acc, [("0", "m_end - m_start"), ("0", "n_end - n_start")]) %} {{ kernel.store_output( tile_Y, tile_acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers )|indent(20, false) @@ -172,14 +208,14 @@ } } {%- if maybe_k_slicing %} - if (num_k_slices > 1) { + if (num_Kt_blocks > 1) { #pragma omp barrier for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { // We slice M-dim and each thread in the k-slicing group works on a slice const int64_t m_start_unsliced = mc * Mr; const int64_t m_end_unsliced = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size_unsliced = m_end_unsliced - m_start_unsliced; - const int64_t m_slice_size = (m_size_unsliced + num_k_slices - 1) / num_k_slices; + const int64_t m_slice_size = (m_size_unsliced + num_Kt_blocks - 1) / num_Kt_blocks; const int64_t m_start = std::min(m_start_unsliced + m_slice_size * k_slice_id, m_end_unsliced); const int64_t m_end = std::min(m_start_unsliced + m_slice_size * (k_slice_id + 1), m_end_unsliced); const int64_t m_size = m_end - m_start; @@ -189,9 +225,9 @@ const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; const int64_t mxn_cache_block_id = (mc / Mc_blocks) * num_Nc_blocks + nc; - auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_k_slices].get(); - for (int64_t other_slice = 1; other_slice < num_k_slices; other_slice++) { - auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_k_slices + other_slice].get(); + auto {{acc_buf_name}} = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks].get(); + for (int64_t other_slice = 1; other_slice < num_Kt_blocks; other_slice++) { + auto other_acc = local_buf_ptrs[mxn_cache_block_id * num_Kt_blocks + other_slice].get(); for (int64_t m = m_offset; m < m_offset + m_size; m++) { #pragma omp simd for (int64_t n = 0; n < n_size; n++) { @@ -377,7 +413,7 @@ def get_cache_blocking(register_blocking, thread_blocking): # The ratios below are empirically determined to decide # the effective sizes of L1 and L2. # TODO: tune the factor here - L1_limit_factor = 1 + L1_limit_factor = 0.8 L2_limit_factor = 0.5 L1_cache_size = ( @@ -531,7 +567,7 @@ def normalize_shapes(inputs, layout_or_out): new_inputs = list(inputs) X = inputs[0] W = inputs[1] - B = inputs[2] if len(inputs) > 2 else None + B = inputs[2] if has_bias else None if isinstance(W, ir.IRNode): if trans_w: if not isinstance(W, ir.TensorBox): @@ -680,8 +716,45 @@ def postprocessor(output): # non-retraceable. To support retracing, we can add a repack node to the # FX graph. For example: # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template + W_tensor_users = 0 + for node in reversed(V.graph.graph.nodes): + # Case may happen when the wgt tensor is used by more than 1 get_attr node + # https://github.com/pytorch/pytorch/issues/134998 + if node.op == "get_attr" and hasattr( + V.graph.module, node.name + ): # wgt might already be deleted + comp_tensor = getattr(V.graph.module, node.name) + if ( + W.is_mkldnn == comp_tensor.is_mkldnn + and W.dtype == comp_tensor.dtype + and W.device == comp_tensor.device + and ( + ( + not W.is_mkldnn + and ( + W.untyped_storage().data_ptr() + == comp_tensor.untyped_storage().data_ptr() + ) + ) + or ( + W.is_mkldnn + and ( + torch.ops.mkldnn.data_ptr(W) + == torch.ops.mkldnn.data_ptr(comp_tensor) + ) + ) + ) + ): + W_tensor_users += 1 + for node in reversed(V.graph.graph.nodes): - if node.name == W_node.get_name() and len(node.users) == 1: + # The wgt tensor has been used by only 1 get_attr node + # The get_attr node has only 1 user fx node + if ( + node.name == W_node.get_name() + and len(node.users) == 1 + and W_tensor_users == 1 + ): del V.graph.constants[node.name] delattr(V.graph.module, node.name) delattr(V.graph.graph.owning_module, node.name) @@ -759,6 +832,7 @@ def render( # type: ignore[override,return] use_local_acc = ( self.layout.dtype != torch.float + or template_buffer_has_other_users or int8_gemm or self.padded_n != self.n or self.maybe_k_slicing() @@ -874,38 +948,45 @@ def copy_inner(index): reindexers.extend([None] * len(epilogue_nodes)) Y_2d = Y else: - # From template_buffer to Y_ordered (ordered by stride decreasingly, in dense format), for example: - # template_buffer: - # size (324, 512), stride (512, 1) - # Y_ordered (ordered by stride decreasingly, in dense format): - # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) - stride_order = list( - ir.get_stride_order(V.graph.sizevars.size_hints(Y.get_stride())) - ) - fill_order = ir.stride_order2fill_order(stride_order) - reversed_fill_order = list(reversed(fill_order)) - size_with_stride_ordered_decreasingly = [ - Y.get_size()[i] for i in reversed_fill_order - ] - reshape_reindex = ir.View.dynamic_reshape_indexer( - size_with_stride_ordered_decreasingly, template_buffer.get_size() - ) - # From Y_ordered (ordered by stride decreasingly, in dense format) to Y, for example: - # Y_ordered (ordered by stride decreasingly, in dense format): - # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) - # Y: - # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) - from_stride_ordered_decreasingly_to_Y_order = [ - (len(stride_order) - 1) - stride_order[i] - for i in range(len(stride_order)) - ] - stride_reindex = ir.same_reorder( - from_stride_ordered_decreasingly_to_Y_order - ) + def get_reindexer(epilogue_node): + # From template_buffer to epilogue_node_ordered (ordered by stride decreasingly, in dense format), for example: + # template_buffer: + # size (324, 512), stride (512, 1) + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + stride_order = list( + ir.get_stride_order( + V.graph.sizevars.size_hints(epilogue_node.get_stride()) + ) + ) + fill_order = ir.stride_order2fill_order(stride_order) + reversed_fill_order = list(reversed(fill_order)) + size_with_stride_ordered_decreasingly = [ + epilogue_node.get_size()[i] for i in reversed_fill_order + ] + reshape_reindex = ir.View.dynamic_reshape_indexer( + size_with_stride_ordered_decreasingly, + template_buffer.get_size(), + ) + + # From epilogue_node_ordered (ordered by stride decreasingly, in dense format) to epilogue_node, for example: + # epilogue_node_ordered (ordered by stride decreasingly, in dense format): + # size (1, 18, 18, 512), stride (165888, 9216, 512, 1) + # epilogue_node: + # size (1, 18, 18, 512), stride (165888, 1, 9216, 512) + from_stride_ordered_decreasingly_to_epilogue_node_order = [ + (len(stride_order) - 1) - stride_order[i] + for i in range(len(stride_order)) + ] + stride_reindex = ir.same_reorder( + from_stride_ordered_decreasingly_to_epilogue_node_order + ) - reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) - reindexers.extend([reindexer] * len(epilogue_nodes)) # type: ignore[list-item] + reindexer = ir.fuse_reindexing(stride_reindex, reshape_reindex) + return reindexer + + reindexers.extend([get_reindexer(epilogue_node) for epilogue_node in epilogue_nodes]) # type: ignore[list-item] if isinstance(Y, ir.BaseView): storage = ir.StorageBox(Y.unwrap_view()) else: @@ -934,6 +1015,12 @@ def copy_inner(index): if isinstance(micro_gemm, CppMicroGemmAMX): counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1 + L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes + assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}" + + L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes + assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}" + options = dict( X=X, W=W, @@ -963,6 +1050,9 @@ def copy_inner(index): w_zp=w_zp, acc_buf_dtype=torch.int32 if int8_gemm else torch.float, DTYPE_TO_CPP=DTYPE_TO_CPP, + L1_cache_size=L1_cache_size, + L2_cache_size=L2_cache_size, + config=config, ) with contextlib.ExitStack() as stack: for buf in fake_buffers: diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index fb6d86d1a720f5..2f0c3e78f56791 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -332,8 +332,8 @@ class CppMicroGemmFP32Vec(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); @@ -364,7 +364,7 @@ class CppMicroGemmFP32Vec(CppMicroGemm): break; {%- endfor %} default: - {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + {{kernel.assert_function}}(false, "Unsupported block_m: {{block_m}}"); } } } @@ -509,8 +509,8 @@ class CppMicroGemmAMX(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { - TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); - TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); + {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + {{kernel.assert_function}}(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); @@ -603,15 +603,11 @@ class CppMicroGemmAMX(CppMicroGemm): {%- if input_dtype == torch.bfloat16 and input2_dtype == torch.int8 %} // create a buffer for tiles of B. - // TODO: loop-unrolling of the "compute" lambda may result in incorrect output - // as this buffer would be used, so maybe 4 of these should be used? - // Since UT output is correct, looks like loop unrolling isn't actually happening. alignas(64) {{input_t}} bf16_weights_buf[512]; int num_b_rows = (last_k_offset > 0) ? 16 : (tail_k_size * sizeof({{input_t}})) / 4; int b_tile_ptr_stride = ldb * {{vnni_size}}; - // TODO: verify whether or not these lambdas inline auto load_B_row = [&]({{input2_t}}* src, {{input_t}}* dst) { {{kernel.unroll_pragma(2)}} for (int i = 0; i < 2; i++) { @@ -801,6 +797,14 @@ def create_from_config(cls, config: CppMicroGemmConfig): ): continue block_m, block_n, block_k = config.register_blocking + if ( + config.vec_isa_cls == VecAMX + and m < block_m + and input_dtype == torch.bfloat16 + and input2_dtype == torch.int8 + ): + # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m + continue # Criteria on the ranking of configurations # 1. ISA: AMX > VEC # 2. Dividable by block sizes (block_m, block_n, block_k) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index e1937d26dac181..fd5c380cd7711e 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -7,6 +7,7 @@ #include #include #include +#include #include // WARNING: be extra careful when including more ATen/c10 header files here! @@ -25,6 +26,7 @@ #include #include #include +#include #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) #define INDUCTOR_USE_VECTOR_TYPES() 1 @@ -628,8 +630,54 @@ atomic_add(volatile T *addr, T offset) { atomic_addr->fetch_add(offset, std::memory_order_relaxed); } -void mm_get_thread_blocking( +#if INDUCTOR_USE_VECTOR_TYPES() +template +void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::VectorizedN offset) { + constexpr int len = at::vec::VectorizedN::size(); + static_assert(len <= at::vec::VectorizedN::size()); + __at_align__ std::array tmpbuf; + __at_align__ std::array tmpidx; + offset.store(tmpbuf.data()); + index.store(tmpidx.data()); + for (int i = 0; i < len; i++){ + atomic_add(addr + tmpidx[i], tmpbuf[i]); + } +} +#endif + +std::tuple, int> _get_factors(int64_t number) { + int count = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + count += 2; + } + } + auto factors = std::shared_ptr(new int64_t[count]); + int index = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + factors[index++] = number / i; + factors[index++] = i; + } + } + return std::make_tuple(factors, count); +} + +std::tuple, int> get_factors(int64_t number) { + thread_local std::map, int>> cache; + auto it = cache.find(number); + if (it != cache.end()) { + return it->second; + } else { + auto factors = _get_factors(number); + cache[number] = factors; + return factors; + } +} + +void _mm_get_thread_blocking( int num_threads, + int max_k_slices, int64_t M, int64_t N, int64_t K, @@ -640,36 +688,18 @@ void mm_get_thread_blocking( int64_t& Nt, int64_t& Kt) { // see NOTE [Thread blocking in Cpp GEMM] for heuristics - // TODO(jgong5): cache thread blocking results Mt = Nt = Kt = 0; - auto get_factors = [](int64_t number) { - int count = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - count += 2; - } - } - auto factors = std::make_unique(count); - int index = 0; - for (int64_t i = std::sqrt(number); i > 0; --i) { - if (number % i == 0) { - factors[index++] = number / i; - factors[index++] = i; - } - } - return std::make_tuple(std::move(factors), count); - }; - - auto get_blocking = [](int64_t num_threads, - int64_t factor, + auto get_blocking = [](int64_t m_factor, + int64_t n_factor, + int64_t k_factor, int64_t m_blocks, int64_t n_blocks, int64_t k_blocks) { - int64_t thread_block_n = (n_blocks + factor - 1) / factor; - int64_t cofactor = num_threads / factor; - int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; - return std::make_tuple(thread_block_m, thread_block_n, k_blocks); + int64_t thread_block_k = (k_blocks + k_factor - 1) / k_factor; + int64_t thread_block_n = (n_blocks + n_factor - 1) / n_factor; + int64_t thread_block_m = (m_blocks + m_factor - 1) / m_factor; + return std::make_tuple(thread_block_m, thread_block_n, thread_block_k); }; auto is_better_blocking = [=](int64_t Mt_, @@ -678,7 +708,7 @@ void mm_get_thread_blocking( int64_t Mt, int64_t Nt, int64_t Kt) { - return Mt == 0 || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr; + return Mt == 0 || Kt_ < Kt || Mt_ * Mr + Nt_ * Nr < Mt * Mr + Nt * Nr; }; int64_t m_blocks = (M + Mr - 1) / Mr; @@ -689,11 +719,11 @@ void mm_get_thread_blocking( assert(count > 0); for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - if (n_blocks >= factor && - m_blocks >= num_threads / factor) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor && m_blocks >= m_factor) { auto [Mt_, Nt_, Kt_] = get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); } @@ -705,11 +735,33 @@ void mm_get_thread_blocking( } for (int i = 0; i < count; ++i) { - int64_t factor = factors[i]; - int64_t cofactor = num_threads / factor; - if (n_blocks >= factor || m_blocks >= cofactor) { + int64_t k_factor = factors[i]; + if (k_blocks >= k_factor && (max_k_slices == 0 || k_factor <= max_k_slices)) { + auto [mxn_factors, mxn_count] = get_factors(num_threads / k_factor); + for (int j = 0; j < mxn_count; ++j) { + int64_t n_factor = mxn_factors[j]; + int64_t m_factor = num_threads / (k_factor * n_factor); + if (n_blocks >= n_factor && m_blocks >= m_factor) { + auto [Mt_, Nt_, Kt_] = get_blocking( + m_factor, n_factor, k_factor, m_blocks, n_blocks, k_blocks); + if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { + std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); + } + } + } + } + } + + if (Mt != 0) { + return; + } + + for (int i = 0; i < count; ++i) { + int64_t n_factor = factors[i]; + int64_t m_factor = num_threads / n_factor; + if (n_blocks >= n_factor || m_blocks >= m_factor) { auto [Mt_, Nt_, Kt_] = get_blocking( - num_threads, factor, m_blocks, n_blocks, k_blocks); + m_factor, n_factor, 1, m_blocks, n_blocks, k_blocks); if (is_better_blocking(Mt_, Nt_, Kt_, Mt, Nt, Kt)) { std::tie(Mt, Nt, Kt) = std::make_tuple(Mt_, Nt_, Kt_); } @@ -719,30 +771,118 @@ void mm_get_thread_blocking( assert(Mt != 0); } -inline void mm_get_thread_blocks( - int thread_id, - int64_t M_blocks, - int64_t N_blocks, - int64_t K_blocks, +void mm_get_thread_blocking( + int num_threads, + int max_k_slices, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mt, Nt, Kt) = it->second; + return; + } else { + _mm_get_thread_blocking(num_threads, max_k_slices, M, N, K, Mr, Nr, Kr, Mt, Nt, Kt); + cache[key] = std::make_tuple(Mt, Nt, Kt); + } +} + +template +void _mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, int64_t Mt_blocks, int64_t Nt_blocks, int64_t Kt_blocks, - int64_t& m_block_start, - int64_t& m_block_end, - int64_t& n_block_start, - int64_t& n_block_end, - int64_t& k_block_start, - int64_t& k_block_end) { - int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; - k_block_start = (thread_id % num_Kt) * Kt_blocks; - k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); - thread_id /= num_Kt; - int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; - n_block_start = (thread_id % num_Nt) * Nt_blocks; - n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); - thread_id /= num_Nt; - m_block_start = std::min(thread_id * Mt_blocks, M_blocks); - m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + // See NOTE [CPP GEMM Cache Blocking Algorithm] for the cache blocking algorithm. + // TODO(jgong5): cache cache blocking results + // TODO: tune the factor here + float L1_limit_factor = 0.8; + float L2_limit_factor = 0.5; + + auto L1 = L1_cache_size * L1_limit_factor; + auto L2 = L2_cache_size * L2_limit_factor; + + constexpr size_t num_byte_A = sizeof(X_t); + constexpr size_t num_byte_B = sizeof(W_t); + + int64_t size_cache_B = Kr * Kt_blocks * Nr * num_byte_B; + Kc_blocks = Kt_blocks; + if (size_cache_B > L1) { + Kc_blocks = (int64_t)std::floor(L1 / (Kr * Nr * num_byte_B)); + } + + float min_Mc_ratio = 2; + int64_t min_Mc_blocks = std::ceil(min_Mc_ratio * Mr / Nr); + auto Kt_bytes = Kt_blocks * Kr * num_byte_A; + if (min_Mc_blocks * Mr * Kt_bytes < L2) { + Mc_blocks = std::min(Mt_blocks, (int64_t)std::floor(L2 / (Mr * Kt_bytes))); + Nc_blocks = 1; + } else { + Mc_blocks = Mt_blocks; + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + auto Nc_bytes = Nc_blocks * Nr * 4; + auto Kc_bytes = Kc_blocks * Kr * num_byte_A; + if (Mc_blocks * Mr * (Kc_bytes + Nc_bytes) > L2) { + auto M_max = (std::sqrt(Kc_bytes * Kc_bytes + 16 * L2) - Kc_bytes) / 8; + if (M_max < Mc_blocks * Mr) { + Mc_blocks = (int64_t)std::floor(M_max / Mr); + Nc_blocks = std::min((int64_t)std::ceil((float)Mc_blocks * Mr / Nr), Nt_blocks); + } + } + } +} + +template +void mm_get_cache_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t Mr, + int64_t Nr, + int64_t Kr, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& Mc_blocks, + int64_t& Nc_blocks, + int64_t& Kc_blocks, + uint32_t L1_cache_size, + uint32_t L2_cache_size) { + thread_local std::map< + std::tuple, + std::tuple> cache; + auto key = std::make_tuple(num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, L1_cache_size, L2_cache_size); + auto it = cache.find(key); + if (it != cache.end()) { + std::tie(Mc_blocks, Nc_blocks, Kc_blocks) = it->second; + return; + } else { + _mm_get_cache_blocking( + num_threads, M, N, K, Mr, Nr, Kr, Mt_blocks, Nt_blocks, Kt_blocks, Mc_blocks, Nc_blocks, Kc_blocks, L1_cache_size, L2_cache_size); + cache[key] = std::make_tuple(Mc_blocks, Nc_blocks, Kc_blocks); + } } struct amx_tilecfg { diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 2bce16b2ad1538..a237924b9182d4 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -111,11 +111,10 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() res.writeline(codecache.cpp_prefix()) - res.splice( - """ - #include "c10/util/Unroll.h" - """ - ) + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + if config.abi_compatible: + res.splice("""#include """) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 4db3e7b699ec63..e72a895dc44d59 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -10,6 +10,7 @@ from .. import config, cpp_builder, ir, lowering as L from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody from ..select_algorithm import PartialRender from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V @@ -106,7 +107,7 @@ def hook(): def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code _, call_args, arg_types = self.args.cpp_argdefs() - wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + wrapper.generate_kernel_call(name, call_args, gpu=False, arg_types=arg_types) def dtype(self, node: ir.Buffer) -> str: return DTYPE_TO_CPP[node.get_dtype()] @@ -239,7 +240,13 @@ def fn(*args): node.make_loader()(new_args).value, ) - body = ir.LoopBody(fn, (list(var_ranges.keys()), ()), var_ranges) + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) bodies.append(body) var_sizes_list.append(var_sizes) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 8dfa3489e7673d..e69bb3637f7249 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -16,6 +16,7 @@ from torch.utils._sympy.value_ranges import ValueRanges from .. import ir +from ..loop_body import LoopBody from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs from ..virtualized import ops, OpsValue, V from .common import ( @@ -880,3 +881,36 @@ def inner_fn(index): inner_fn=inner_fn, ranges=input_buffer.get_size(), ) + + +def _get_loop_body(fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = set() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6ced8aefc50f86..4e88a6f91f7e22 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -49,7 +49,6 @@ def __init__(self): self.extern_call_ops = set() self.size = "sizes()" self.stride = "strides()" - self.cuda = False self.supports_intermediate_hooks = False self.outputs_need_copy = set() self.kernel_callsite_id = count() @@ -71,37 +70,42 @@ def __init__(self): def generate_kernel_call( self, - name, + kernel_name: str, call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): """ Generates kernel call code. - cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. triton: Defines whether the GPU backend uses Triton for codegen. Otherwise it uses the CUDA language for codegen. Only valid when cuda == True. """ - if cuda: + if gpu: return super().generate_kernel_call( - name, + kernel_name, call_args, grid, device_index, - cuda, + gpu, triton, arg_types, + raw_args, grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, ) else: if config.abi_compatible: @@ -119,9 +123,9 @@ def generate_kernel_call( else: # arg is a scalar new_args.append(arg) - self.writeline(self.wrap_kernel_call(name, new_args)) + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) else: - self.writeline(self.wrap_kernel_call(name, call_args)) + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) def write_constant(self, name, hashed): # include a hash so our code cache gives different constants different files @@ -152,12 +156,9 @@ def write_header(self): ) if config.abi_compatible: - if config.c_shim_version == "1": - self.header.splice("#include ") - else: - self.header.splice( - f"#include " - ) + self.header.splice( + f"#include " + ) self.header.splice( """ #include @@ -244,6 +245,11 @@ class RAIIPyObject { """ ) + @functools.lru_cache(None) # noqa: B019 + def include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + def mark_output_type(self): # mark output type to unwrap tensor back to python scalar from ..ir import ShapeAsConstantBuffer @@ -913,7 +919,7 @@ def finalize_prefix(self): self.prefix = cached_dtypes_buffer def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=False + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=False ): self.header.splice(f"\n{kernel}\n") @@ -1129,7 +1135,7 @@ def generate_end(self, result): result.splice( f""" inductor_entry = CppWrapperCodeCache.load_pybinding( - ["std::vector"], cpp_wrapper_src, {self.cuda}, {len(V.graph.graph_outputs)}) + ["std::vector"], cpp_wrapper_src, "{self.device}", {len(V.graph.graph_outputs)}) """ ) @@ -1152,6 +1158,10 @@ def generate_end(self, result): wrapper_body += """ input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) """ + # Release the inputs for memory reuse. + wrapper_body += """ + args.clear() + """ # unwrap output tensor back to python scalar if all(x for x in self.output_is_tensor.values()): @@ -1192,20 +1202,8 @@ def get_c_shim_func_name(self, kernel): kernel_suffix = kernel_tokens[-1] if kernel_suffix == "call": kernel_suffix = kernel_tokens[-2] - if config.c_shim_version == "1": - # For sdpa, we need the v2 version since v1 didn't consider optional arg - # FIXME: no need to do this after we switch to the torchgen-ed C shim - if kernel_suffix == "_scaled_dot_product_flash_attention": - shim_fn = "aoti_torch__scaled_dot_product_flash_attention_v2" - elif kernel_suffix == "_scaled_mm": - shim_fn = "aoti_torch__scaled_mm_v2" - elif kernel_suffix.startswith("wrapped_fbgemm"): - assert self.device == "cpu", "Using wrapped_fbgemm out of CPU!" - shim_fn = f"aoti_torch_cpu_{kernel_suffix}" - else: - shim_fn = f"aoti_torch_{kernel_suffix}" - else: - shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" + + shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}" return shim_fn def generate_c_shim_extern_kernel_call(self, kernel, args): @@ -1215,10 +1213,7 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): self.allow_stack_allocation = False wrapped_args = [] - args_to_print = None - enable_debug_printer = config.aot_inductor.debug_intermediate_value_printer - if enable_debug_printer: - args_to_print = [] + debug_printer_manager = V.graph.wrapper_code.debug_printer for x in args: pieces = x.split(", ") @@ -1232,16 +1227,10 @@ def generate_c_shim_extern_kernel_call(self, kernel, args): if isinstance(piece, str) and piece.startswith( ("buf", "arg", "wrap_with_raii_handle_if_needed") ): - # TODO: The current way to find a 'tensor' type arg is hacky also as mentioned above - # Find a more reliable way to detect tensor kernel args for extern kernel calls - if enable_debug_printer: - if piece.startswith(("buf", "arg")): - args_to_print.append(piece) piece = f"convert_arrayref_tensor_to_tensor({piece})" wrapped_args.append(piece) - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args(args_to_print, kernel, None, None) + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel) self.writeline( @@ -1261,7 +1250,11 @@ def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): def generate_extern_kernel_alloc(self, extern_kernel, args): if config.abi_compatible: - self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + if hasattr(extern_kernel, "outputs"): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) + else: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) else: super().generate_extern_kernel_alloc(extern_kernel, args) @@ -1335,19 +1328,11 @@ def generate_scatter_fallback( # No stack allocation when there is a fallback op self.allow_stack_allocation = False - # TODO: needs updates to use C shim v2 if config.abi_compatible: # call the ABI shim function instead of the ATen one - if config.c_shim_version == "1": - cpp_kernel_name = ( - "aoti_torch_scatter_reduce_out" - if python_kernel_name.startswith("aten.scatter_reduce") - else "aoti_torch_scatter_out" - ) - else: - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) - # C shim only contains out-variant instead of inplace-variant - cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" inputs_wrapped = [ f"convert_arrayref_tensor_to_tensor({x})" if isinstance(x, str) @@ -1375,7 +1360,7 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): # No stack allocation when there is a fallback op self.allow_stack_allocation = False - # TODO: needs updates to use C shim v2 + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version if config.abi_compatible: # See the comment in codegen_reinterpret_view about why having something like # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding @@ -1495,7 +1480,8 @@ def generate_profiler_mark_wrapper_call(self, stack): 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' ) - def write_triton_header_once(self): + @cache_on_self + def write_triton_header_once(self) -> None: pass def generate_start_graph(self): @@ -1643,6 +1629,11 @@ def make_allocation( f"at::Tensor {name} = at::detail::empty_strided_cuda(" f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" ) + if device.type == "xpu": + return ( + f"at::Tensor {name} = at::detail::empty_strided_xpu(" + f"{size}, {stride}, {dtype_code}, c10::DeviceType::XPU);" + ) return ( f"{self.declare}{name} = {self.namespace}empty_strided(" f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" @@ -1731,7 +1722,7 @@ def create_dtypeview_call(reinterpret_call): ) call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] dtype_name = str(dtype).split(".")[-1] - device_name = "cuda" if data.layout.device.type == "cuda" else "cpu" + device_name = data.layout.device.type get_dtype_function = f"aoti_torch_dtype_{dtype_name}" dtypeview_function = f"aoti_torch_{device_name}_view_dtype" call_strs.append( @@ -2131,8 +2122,9 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed( self.allow_stack_allocation = False def extract_output_name(out): - assert out is not None, "None, i.e. optional output is not supported" - if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + if out is None: + return None + elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() elif isinstance(out, (list, tuple)): return type(out)(extract_output_name(o) for o in out) @@ -2142,9 +2134,13 @@ def extract_output_name(out): # output_args has the same pytree structure as outputs output_args = None if config.abi_compatible: - output_args = extract_output_name(outputs) - if isinstance(output_args, str): - output_args = [output_args] + if outputs is None: + # outputs is not specified, the default is to write to buf_name + output_args = [buf_name] + else: + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] if V.graph.aot_mode and config.abi_compatible: assert op_overload is not None @@ -2207,10 +2203,24 @@ def load_custom_op_wrapper(self): self.custom_op_wrapper_loaded = True def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): - def generate_py_arg_inner(raw_arg, arg_type): - if isinstance(arg_type, torch.TensorType): + def generate_py_arg_inner(lines, raw_arg, arg_type): + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): # Store AtenTensorHandle as void* - return f"PyCapsule_New(reinterpret_cast({raw_arg.codegen_reference()}.get()), NULL, NULL)" + base_handle = raw_arg.codegen_reference() + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + lines.append(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) elif isinstance(arg_type, torch.IntType): # int return f"PyLong_FromLongLong({raw_arg})" @@ -2244,21 +2254,33 @@ def generate_py_arg_inner(raw_arg, arg_type): raise NotImplementedError( f"arg type {arg_type} with raw_arg {raw_arg}, {type(raw_arg)} is not yet supported by custom_op_wrapper" ) + elif isinstance(raw_arg, torch.dtype): + # dtype + self.include_extra_header("torch/csrc/DynamicTypes.h") + return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" else: raise NotImplementedError( f"arg type {arg_type} is not yet supported by custom_op_wrapper" ) - lines = "" + lines = [] if isinstance(arg_type, torch.ListType): assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" - lines += f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) for i, elem in enumerate(raw_arg): - lines += f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(elem, arg_type.getElementType())});\n" - lines += f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) else: - lines += f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(raw_arg, arg_type)});\n" - return lines + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( self, @@ -2307,6 +2329,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( """ assert op_overload is not None, "op_overload should not be None" + for idx, (raw_arg, schema_arg) in enumerate( zip(raw_args, op_overload._schema.arguments) ): @@ -2328,13 +2351,16 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( else: # result is a tuple of tensors for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue lines += f""" {output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));""" declarations_before_scope = [ f"RAIIAtenTensorHandle {output_arg};" - for idx, output_arg in enumerate(output_args) + for output_arg in output_args + if output_arg is not None ] scope_gil_acquire = self.generate_scoped_gil_acquire( declarations_before_scope, lines @@ -2377,12 +2403,12 @@ def generate_reset_kernel_saved_flags(self): def generate_save_uncompiled_kernels(self): pass - def c_type_for_prim_type(self, type_) -> str: + def c_type_for_prim_type(self, val, type_) -> str: assert ( config.abi_compatible ), "c_type_for_prim_type is only used in ABI compatible mode" if isinstance(type_, torch.OptionalType): - return f"{self.c_type_for_prim_type(type_.getElementType())}*" + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" elif isinstance(type_, torch.TensorType): return "AtenTensorHandle" elif isinstance(type_, (torch.IntType, torch.SymIntType)): @@ -2393,6 +2419,22 @@ def c_type_for_prim_type(self, type_) -> str: return "int32_t" elif isinstance(type_, torch.FloatType): return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + return "int32_t" + elif isinstance(val, int): + return "int64_t" + elif isinstance(val, float): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" else: raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") @@ -2481,10 +2523,10 @@ def val_to_arg_str(self, val, type_=None) -> str: return f"&{var_name}, {aux}" else: self.writeline( - f"{self.c_type_for_prim_type(element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" ) return f"&{var_name}" - elif config.c_shim_version == "2": + else: # type_ is Optional[Tensor] # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim base_handle = self.val_to_arg_str(val, element_type) @@ -2492,19 +2534,13 @@ def val_to_arg_str(self, val, type_=None) -> str: base_handle = ( f"convert_arrayref_tensor_to_tensor({base_handle})" ) - if base_handle.startswith( - ( - "convert_arrayref_tensor_to_tensor", - "wrap_with_raii_handle_if_needed", - ) - ): - # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to - # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. - tmp_var_name = f"var_{next(self.arg_var_id)}" - self.writeline( - f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};" - ) - base_handle = tmp_var_name + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + self.writeline(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var var_name = f"var_{next(self.arg_var_id)}" self.writeline( f"AtenTensorHandle {var_name} = {base_handle}.get();" @@ -2524,12 +2560,12 @@ def val_to_arg_str(self, val, type_=None) -> str: # Zero-size array is not supported in the C or C++ standard, so # we declare a null pointer for it. self.writeline( - f"const {self.c_type_for_prim_type(element_type)}* {var_name} = nullptr;" + f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" ) else: result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" self.writeline( - f"const {self.c_type_for_prim_type(element_type)} {var_name}[] = {result};" + f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" ) # Need to pass the array length because we can't use std::vector return f"{var_name}, {len(val)}" @@ -2537,3 +2573,20 @@ def val_to_arg_str(self, val, type_=None) -> str: return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var(self, base_handle): + if base_handle.startswith( + ( + "convert_arrayref_tensor_to_tensor", + "wrap_with_raii_handle_if_needed", + ) + ): + # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to + # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call. + tmp_var_name = f"var_{next(self.arg_var_id)}" + return ( + tmp_var_name, + f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};\n", + ) + else: + return "", "" diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py similarity index 56% rename from torch/_inductor/codegen/cpp_wrapper_cuda.py rename to torch/_inductor/codegen/cpp_wrapper_gpu.py index d0143844f432ca..5719d3eba589f8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cuda.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -2,7 +2,7 @@ import functools import os from itertools import chain, count -from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import sympy @@ -12,11 +12,11 @@ from .. import config from ..codecache import CudaKernelParamCache -from ..utils import DeferredLineBase +from ..utils import DeferredLineBase, get_gpu_type from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper -from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header -from .cpp_utils import DTYPE_TO_CPP +from .common import get_device_op_overrides +from .cpp_utils import cexpr, DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu from .wrapper import SymbolicCallArg @@ -25,9 +25,9 @@ from ..graph import GraphLowering -class DeferredCudaKernelLine(DeferredLineBase): +class DeferredGpuKernelLine(DeferredLineBase): """ - When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels to be tuned and stored as cubin files, so use a deferred line to backfill those information """ @@ -58,19 +58,117 @@ def __call__(self): return self.line_template % tuple(params[key] for key in self.keys) def _new_line(self, line): - return DeferredCudaKernelLine(self.kernel_name, line, self.keys) + return DeferredGpuKernelLine(self.kernel_name, line, self.keys) -class CppWrapperCuda(CppWrapperCpu): +class DeferredGpuDefaultGrid: + """ + A container for the default grid, which may be used by DeferredCudaGridLine + """ + + def __init__( + self, + kernel_name: str, + grid, + grid_callable: Optional[Callable[..., Any]] = None, + **grid_extra_kwargs, + ): + self.kernel_name = kernel_name + self.grid = grid + self.grid_callable = grid_callable + self.grid_extra_kwargs = grid_extra_kwargs + + def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): + if isinstance(grid, (list, tuple)): + return [self._process_grid(e) for e in grid] + else: + return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid + + def __call__(self): + grid = self.grid + assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" + grid = self._process_grid(grid) + grid_callable = self.grid_callable or default_grid + if not self.grid_extra_kwargs: + grid_fn = grid_callable(*grid) + else: + grid_fn = grid_callable(*grid, **self.grid_extra_kwargs) + + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + block_cfg = { + "XBLOCK": params["x_block"], + "YBLOCK": params["y_block"], + "ZBLOCK": params["z_block"], + } + return grid_fn(block_cfg) + + +class DeferredGpuGridLine(DeferredLineBase): + """ + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred line to backfill those information + """ + + def __init__( + self, + kernel_name: str, + grid_var: str, + grid, + autotune_configs, + ): + super().__init__("") + self.kernel_name = kernel_name + self.grid_var = grid_var + self.grid = grid + self.autotune_configs = autotune_configs + + def __call__(self): + params = CudaKernelParamCache.get(self.kernel_name) + assert ( + params is not None + ), f"{self.kernel_name} not found in CudaKernelParamCache" + + if self.autotune_configs is not None: + # This indicates the Triton kernel is a user-defined one. + grid = None + if len(self.grid) == 1: + grid = self.grid[0] + else: + for i, c in enumerate(self.autotune_configs): + if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): + grid = self.grid[i] + break + assert grid is not None + elif isinstance(self.grid, DeferredGpuDefaultGrid): + grid = self.grid() + else: + grid = self.grid + + assert len(grid) != 0, "Grid can't be empty" + grid_args_str = ", ".join( + [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + ) + return f" Grid {self.grid_var} = Grid({grid_args_str});" + + def _new_line(self, line): + return DeferredGpuGridLine( + self.kernel_name, self.grid_var, self.grid, self.autotune_configs + ) + + +class CppWrapperGpu(CppWrapperCpu): """ Generates cpp wrapper for running on GPU and calls CUDA kernels """ def __init__(self) -> None: - self.device = "cuda" + self.device = get_gpu_type() + self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() - self.cuda = True def write_header(self): if V.graph.is_const_graph: @@ -81,26 +179,32 @@ def write_header(self): self.header.splice("#include ") if config.abi_compatible: + self.header.splice(self.device_codegen.abi_compatible_header()) + else: self.header.splice( - "#include " + maybe_hipify_code_wrapper(self.device_codegen.kernel_header()) ) - else: - self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_header())) - self.header.splice(maybe_hipify_code_wrapper(cuda_kernel_driver())) + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) def write_get_raw_stream(self, index, graph=None): name = f"stream{index}" - self.writeline(maybe_hipify_code_wrapper(f"cudaStream_t {name};")) self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream({index}, (void**)&{name}));" + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({index}, (void**)&{name}));" ) return name def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=True ): - if not cuda: - return super().define_kernel(name, kernel, metadata, cuda) + if not gpu: + return super().define_kernel(name, kernel, metadata, gpu) def generate(self, is_inference): self.prefix.writeline("\n") @@ -110,13 +214,21 @@ def generate(self, is_inference): sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]), ): self.prefix.writeline( - maybe_hipify_code_wrapper(f"static CUfunction {kernel} = nullptr;") + maybe_hipify_code_wrapper( + f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;" + ) ) self.prefix.writeline("\n") return super().generate(is_inference) def generate_user_defined_triton_kernel( - self, kernel_name, raw_args, grid, configs, triton_meta, constexprs + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): # in C++ wrapper, we don't pass constexpr args, as they don't # get added as parameters to the PTX code compiled from the @@ -124,20 +236,6 @@ def generate_user_defined_triton_kernel( raw_args = [ raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs ] - - assert len(grid) != 0 - if len(grid) == 1: - grid_decision = grid[0] - else: - meta = CudaKernelParamCache.get(kernel_name) - assert meta is not None - grid_decision = None - for i, c in enumerate(configs): - if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()): - grid_decision = grid[i] - break - assert grid_decision is not None - args = [self.val_to_arg_str(v) for v in raw_args] arg_types = [ arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg) @@ -147,10 +245,12 @@ def generate_user_defined_triton_kernel( kernel_name, args, arg_types=arg_types, - grid=grid_decision, - cuda=True, + raw_args=raw_args, + grid=grid, + gpu=True, triton=True, triton_meta=triton_meta, + autotune_configs=configs, ) @functools.lru_cache(None) # noqa: B019 @@ -163,11 +263,15 @@ def generate_load_kernel_once( kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name self.writeline(f"if ({kernel_var_name} == nullptr) {{") self.writeline( - DeferredCudaKernelLine( + DeferredGpuKernelLine( kernel_name, - kernel_var_name + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" + """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" if V.graph.aot_mode - else kernel_var_name + """ = loadKernel("%s", "%s", %s);""", + else """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s);""", keys, ) ) @@ -205,7 +309,9 @@ def generate_args_decl(self, call_args, arg_types): else: if config.abi_compatible: self.writeline( - maybe_hipify_code_wrapper(f"CUdeviceptr {var_name};") + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_device_ptr()} {var_name};" + ) ) self.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" @@ -213,7 +319,8 @@ def generate_args_decl(self, call_args, arg_types): else: self.writeline( maybe_hipify_code_wrapper( - f"CUdeviceptr {var_name} = reinterpret_cast({arg}.data_ptr());" + f"{self.device_codegen.cpp_device_ptr()} {var_name} = \ + reinterpret_cast<{self.device_codegen.cpp_device_ptr()}>({arg}.data_ptr());" ) ) elif arg_type in (sympy.Integer, int): @@ -228,58 +335,58 @@ def generate_args_decl(self, call_args, arg_types): def generate_default_grid( self, - name: str, + kernel_name: str, grid: List[Any], - cuda: bool = True, + gpu: bool = True, grid_callable: Optional[Callable[..., Any]] = None, **grid_extra_kwargs, ): """ Generate grid configs for launching a CUDA kernel using the grid - function from triton_heuristics. + function from triton_heuristics. Because its computation needs + to read kernel config after autotune, it is done in a deferred way + using DeferredGpuDefaultGrid. """ - if not cuda: + if not gpu: return grid - assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list" - grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid] - grid_callable = grid_callable or default_grid - if not grid_extra_kwargs: - grid_fn = grid_callable(*grid) - else: - grid_fn = grid_callable(*grid, **grid_extra_kwargs) - params = CudaKernelParamCache.get(name) - assert ( - params is not None - ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}" - block_cfg = { - "XBLOCK": params["x_block"], - "YBLOCK": params["y_block"], - "ZBLOCK": params["z_block"], - } - return grid_fn(block_cfg) + return DeferredGpuDefaultGrid( + kernel_name, grid, grid_callable, **grid_extra_kwargs + ) def generate_kernel_call( self, - kernel_name, + kernel_name: str, call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): assert arg_types is not None and len(call_args) == len( arg_types ), "call_args and arg_types do not match" - if not cuda: - # Even in CppWrapperCuda, we may see cpp kernels + if not gpu: + # Even in CppWrapperGpu, we may see cpp kernels return super().generate_kernel_call( - kernel_name, call_args, grid, device_index, cuda, triton, arg_types + kernel_name, + call_args, + grid, + device_index, + gpu, + triton, + arg_types, + raw_args, + grid_fn, + triton_meta, + autotune_configs, + grid_extra_kwargs, ) device_index, call_args = self.prepare_triton_kernel_call( @@ -299,41 +406,37 @@ def generate_kernel_call( call_args = [arg for i, arg in enumerate(call_args) if i not in equal_to_1] arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] - call_args = self.generate_args_decl(call_args, arg_types) + call_args_str = self.generate_args_decl(call_args, arg_types) kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" - self.writeline(f"void* {kernel_args_var}[] = {{{call_args}}};") + self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};") stream = ( "stream" if V.graph.aot_mode else self.write_get_raw_stream(device_index, V.graph) ) - grid_name = f"{kernel_name}_grid_{next(self.grid_id)}" - assert isinstance( - grid, (list, tuple) - ), f"expected grid to be a list or tuple but got: {grid=}" - - grid = [V.graph.sizevars.simplify(item) for item in grid] - grid_uses_symbolic_shapes = any(item.free_symbols for item in grid) - grid_args = [self.expr_printer(item) for item in grid] - grid_args_str = ", ".join(grid_args) - self.writeline(f"Grid {grid_name} = Grid({grid_args_str});") - - if grid_uses_symbolic_shapes: - self.writeline(f"if ({grid_name}.is_non_zero()) {{") - kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + grid_var = f"{kernel_name}_grid_{next(self.grid_id)}" self.writeline( - DeferredCudaKernelLine( - kernel_name, - r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( - kernel_var_name, - f"{grid_name}.grid_x", - f"{grid_name}.grid_y", - f"{grid_name}.grid_z", - kernel_args_var, - stream, - ), - ("num_warps", "shared_mem"), - ), + DeferredGpuGridLine(kernel_name, grid_var, grid, autotune_configs) ) - if grid_uses_symbolic_shapes: + + kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + # add debug printer code for all triton kernel related calls + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"if ({grid_var}.is_non_zero()) {{") + self.writeline( + DeferredGpuKernelLine( + kernel_name, + r" launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format( + kernel_var_name, + f"{grid_var}.grid_x", + f"{grid_var}.grid_y", + f"{grid_var}.grid_z", + kernel_args_var, + stream, + ), + ("num_warps", "shared_mem"), + ), + ) self.writeline("}") diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index d6472a48f1e08c..dfb0b159e2f7b8 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -180,7 +180,7 @@ def call_kernel( wrapper.generate_kernel_call( name, call_args, - cuda=True, + gpu=True, triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 7ff99b871c82d6..011e503a7b889d 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +import torch + from ..common import DeviceOpOverrides, register_device_op_overrides @@ -15,5 +17,120 @@ def synchronize(self): def device_guard(self, device_idx): return f"torch.cuda._DeviceGuard({device_idx})" + def cpp_device_guard(self): + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self): + return "AOTICudaGuard" + + def cpp_stream_guard(self): + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self): + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self): + return "at::cuda::getStreamFromExternal" + + def kernel_header(self): + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self): + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + cuGetErrorString(code, &msg); \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + namespace { + + struct Grid { + Grid(uint32_t x, uint32_t y, uint32_t z) + : grid_x(x), grid_y(y), grid_z(z) {} + uint32_t grid_x; + uint32_t grid_y; + uint32_t grid_z; + + bool is_non_zero() { + return grid_x > 0 && grid_y > 0 && grid_z > 0; + } + }; + + } // anonymous namespace + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def abi_compatible_header(self): + return "#include " + + def cpp_stream_type(self): + return "cudaStream_t" + + def aoti_get_stream(self): + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self): + return "CUfunction" + + def cpp_device_ptr(self): + return "CUdeviceptr" + register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 90971045379656..9273eeb5f9824f 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -30,7 +30,7 @@ def __init__(self, scheduler: Scheduler) -> None: self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) - def get_backend_features(self, device): + def get_backend_features(self, device): # type:ignore[override] return self._triton_scheduling.get_backend_features(device) def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index a44fa5d1d8b1fa..32d4d55e58545f 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -2,92 +2,239 @@ from __future__ import annotations import functools +import logging +import os +from enum import Enum from typing import List, Optional +from torch import dtype as torch_dtype + from .. import config from ..virtualized import V -from .common import TensorArg +from .multi_kernel import MultiKernel -class DebugPrinterManager: - DEBUG_FILTER_DEFAULT_PRINT_ALL = "default" +log = logging.getLogger(__name__) + + +def _print_debugging_tensor_value_info(msg, arg): + # helper for printing debugging stats for intermediate tensor values + # at jit inductor level codegen + max_numel_to_print = 64 + print(msg) + numel = arg.float().numel() + # print the debug printing stats + if numel <= max_numel_to_print: + print(arg) + print("Number of elements: ", numel) + print("Size: ", arg.float().size()) + print("Dtype: ", arg.float().mean().item()) + print("Mean: ", arg.float().mean().item()) + print("Min: ", arg.float().min().item()) + print("Max: ", arg.float().max().item()) + print("Std: ", arg.float().std().item()) + +# AOTI debug printing related configs +class IntermediateValueDebuggingLevel(Enum): + # OFF: No intermediate tensor value debug info will be printed or saved. + OFF = "0" + # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed. + SAVE_ONLY = "1" + # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed. + PRINT_ONLY = "2" + # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed. + # This mode can be helpful in cases when you just want to pinpointing what kernel is running into a CUDA IMA issue, etc. + PRINT_KERNEL_NAMES_ONLY = "3" + + +class DebugPrinterManager: def __init__( self, - enable_debug_printer: bool, - args_to_print: Optional[List[str]] = None, + debug_printer_level, + args_to_print_or_save: Optional[List[str]] = None, kernel_name: str = "", kernel=None, arg_signatures: Optional[List[type]] = None, ): - self.enable_debug_printer = enable_debug_printer - if args_to_print is None: - args_to_print = [] - self.args_to_print = args_to_print + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures: Optional[List[type]] = None self.kernel = kernel - self.filtered_kernel_names_to_print = self.get_debug_filtered_kernel_names() + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() def __enter__(self): - if self.enable_debug_printer: - V.graph.all_codegen_kernel_names.add(self.kernel_name) - self.codegen_intermediate_tensor_value_printer( - self.args_to_print, + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[List[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, self.kernel_name, - before_launch=True, + before_launch, arg_signatures=self.arg_signatures, ) - - def __exit__(self, args_to_print, kernel_name, arg_signatures): - if self.enable_debug_printer: - self.codegen_intermediate_tensor_value_printer( - self.args_to_print, + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, self.kernel_name, - before_launch=False, + before_launch, arg_signatures=self.arg_signatures, ) + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + # Print all kernel names to the console only + self.codegen_intermediate_tensor_value_print( + [], + self.kernel_name, + before_launch, + ) + + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> List[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] def set_printer_args( self, - args_to_print: List[str], + args_to_print_or_save: List[str], kernel_name: str, arg_signatures: Optional[List[type]], kernel, + kernel_type=None, ): - self.args_to_print = args_to_print + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." + ) + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + + # Note: if the kernel type is an extern kernel, we do a special handling to get the list of args_to_print_or_save + # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls + if kernel_type == "extern": + args_to_print_or_save_extern = [] + for arg in args_to_print_or_save: + if arg.startswith(("buf", "arg")): + args_to_print_or_save_extern.append(arg) + self.args_to_print_or_save = args_to_print_or_save_extern + else: + self.args_to_print_or_save = args_to_print_or_save self.kernel_name = kernel_name self.arg_signatures = arg_signatures self.kernel = kernel - @functools.lru_cache # noqa: B019 - def get_debug_filtered_kernel_names(self) -> List[str]: - return [ - x.strip() - for x in config.aot_inductor.filtered_kernel_names.lower().split(",") - ] + def codegen_intermediate_tensor_value_save( + self, + args_to_save, + kernel_name, + before_launch=True, + arg_signatures: Optional[List[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_save): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + # TODO: add non-abi compatible mode debug printing info + pass + else: + cwd = os.getcwd() + saved_dir = cwd + "/tmp/jit_inductor/" + if not os.path.exists(saved_dir): + log.info( + "Creating directory to save inductor intermediate tensor values." + ) + os.makedirs(saved_dir) + # Save the model to the directory + saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt" + log.info( + "Saved intermediate tensor %s for %s to %s", + arg, + kernel_name, + saved_path, + ) + line = f"torch.save({arg}, '{saved_path}')" + V.graph.wrapper_code.writeline(line) - def codegen_intermediate_tensor_value_printer( + def codegen_intermediate_tensor_value_print( self, args_to_print, kernel_name, before_launch=True, arg_signatures: Optional[List[type]] = None, ) -> None: + launch_prefix = "before_launch" if before_launch else "after_launch" + + # if the debug printing level is PRINT_KERNEL_NAMES_ONLY + # we only print the kernel name to the console + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + if V.graph.cpp_wrapper: + if config.abi_compatible: + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]");' + ) + V.graph.wrapper_code.writeline('printf("\\n");') + return + for i, arg in enumerate(args_to_print): if arg_signatures is not None and not isinstance( - arg_signatures[i], TensorArg - ): - continue - if ( - len(self.filtered_kernel_names_to_print) > 0 - and self.filtered_kernel_names_to_print[0] - != self.DEBUG_FILTER_DEFAULT_PRINT_ALL - and kernel_name not in self.filtered_kernel_names_to_print + arg_signatures[i], torch_dtype ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type continue - launch_prefix = "before_launch" if before_launch else "after_launch" - if V.graph.cpp_wrapper: + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # when debug printing is enabled i.e. IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name not in self.filtered_kernel_names_to_print + ): + continue + if config.abi_compatible: V.graph.wrapper_code.writeline( f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' @@ -96,5 +243,6 @@ def codegen_intermediate_tensor_value_printer( # TODO: add non-abi compatible mode debug printing info pass else: - line = f"print('{launch_prefix} - {kernel_name} - {arg}', {arg})" - V.graph.wrapper_code.writeline(line) + V.graph.wrapper_code.writeline( + f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' + ) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 20968a57a4444f..337aa544b0d10a 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -1638,7 +1638,7 @@ def call_kernel(self, name: str, node=None): wrapper.generate_kernel_call( name, call_args, - cuda=False, # grid/stream is handled internally in halide + gpu=False, # grid/stream is handled internally in halide ) def generate_assert(self, check): diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 020c38dcc77e4e..6081530fd98e54 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -309,7 +309,7 @@ def inner(): return inner return [ - benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40, fast_flush=True) + benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40) for kernel in self.kernels ] diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py index fff132a9ab7c71..d247103a9d4019 100644 --- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -57,6 +57,7 @@ class CKGemmTemplate(CKTemplate): LDB, std::array{ {{'LDD' if has_bias else ''}} }, LDC, + 1, // kBatch PassThrough {}, // a_elementwise_op PassThrough {}, // b_elementwise_op {{epilogue}} // c_elementwise_op diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index 9029dbe644a5ba..ace9910685ca2b 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -157,7 +157,7 @@ def call_kernel( name, kernel_args, device_index=current_device.index, - cuda=True, + gpu=True, triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 9a2c7982fdbcb8..148d062b7a7a3e 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1064,13 +1064,12 @@ def can_fuse(self, node1, node2): def generate_node_schedule(self, nodes, numel, rnumel): node_schedule: List[Any] = [] - current_loop_writes: OrderedSet[str] = OrderedSet() - + done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() # Writes with a reduced shape, meaning they are only present once the # reduction loop has ended - current_loop_reduced_writes: OrderedSet[str] = OrderedSet() - current_loop_has_writes = False - done: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() + maybe_split_index: Optional[int] = None def fits_in_main_body(n): _, (node_numel, node_rnumel) = n.group @@ -1082,11 +1081,17 @@ def fits_outside_reduction(n): _, (node_numel, node_rnumel) = n.group return node_numel == numel and node_rnumel == 1 and rnumel != 1 + def expect_improved_memory_usage(n): + for read in n.read_writes.reads: + if read.name in current_loop_buffer_usage: + return True + return False + def schedule_node_in_loop(n): - nonlocal current_loop_has_writes done.add(n) node_schedule.append(n) - current_loop_has_writes = True + current_loop_buffer_usage.update([x.name for x in n.read_writes.reads]) + # A scan is modelled as a reduction in the scheduler but has a # full sized output that can be used inside the loop body if ( @@ -1095,50 +1100,53 @@ def schedule_node_in_loop(n): and isinstance(n.node, ir.ComputedBuffer) and not isinstance(n.node.data, ir.Scan) ): - current_loop_reduced_writes.add(n.get_name()) + not_ready_yet_nodes.add(n.get_name()) + else: # this node is available within the loop + current_loop_buffer_usage.update([x.name for x in n.read_writes.writes]) @contextlib.contextmanager def end_current_reduction_loop(): - nonlocal current_loop_has_writes - if current_loop_has_writes: - # flush out any other runnable nodes to reduce number of loops - for other_node in nodes[index + 1 :]: - if ( - node not in done - and fits_in_main_body(other_node) - and not (current_loop_reduced_writes & other_node.ancestors) - ): - schedule_node_in_loop(node) - + nonlocal maybe_split_index if node_schedule and node_schedule[-1] is EnableReduction: node_schedule.pop() else: node_schedule.append(DisableReduction) + if maybe_split_index: + node_schedule.insert(maybe_split_index, DisableReduction) + node_schedule.insert(maybe_split_index + 1, EnableReduction) + maybe_split_index = None yield node_schedule.append(EnableReduction) - current_loop_reduced_writes.clear() - current_loop_has_writes = False + not_ready_yet_nodes.clear() + current_loop_buffer_usage.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) for index, node in enumerate(nodes): if node in done: continue done.add(node) - def requires_closing_previous_reduction(node, node_schedule): - if rnumel == 1: - return False - if not current_loop_reduced_writes & node.ancestors: - return False - assert node_schedule and not isinstance( - node_schedule[-1], (EnableReduction, DisableReduction) - ) - return bool(current_loop_reduced_writes) - if fits_in_main_body(node): if requires_closing_previous_reduction(node, node_schedule): with end_current_reduction_loop(): pass # need to start a new reduction loop + if current_loop_buffer_usage and not expect_improved_memory_usage(node): + # If we don't improve memory usage, then it is better to split into two loops + maybe_split_index = maybe_split_index or len(node_schedule) + else: + # Memory usage got improved, cancel the loop split + maybe_split_index = None + schedule_node_in_loop(node) elif fits_outside_reduction(node): with end_current_reduction_loop(): @@ -1393,25 +1401,7 @@ def _node_has_sort(node): node.mark_run() self.codegen_comment(node_schedule) - - # debug printing values of intermediate tensors - # Note: MultiKernel debug printing is not supported for now - enable_debug_printer = ( - config.aot_inductor.debug_intermediate_value_printer - and not isinstance(final_kernel, MultiKernel) - ) - _, call_args, arg_signatures, _ = ( - final_kernel.args.python_argdefs() - if not isinstance(final_kernel, MultiKernel) - else [None, [], None, None] - ) - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.enable_debug_printer = enable_debug_printer - debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_signatures, final_kernel - ) - with debug_printer_manager: - final_kernel.call_kernel(final_kernel.kernel_name) + final_kernel.call_kernel(final_kernel.kernel_name) if config.nan_asserts: final_kernel.codegen_nan_check() @@ -1534,15 +1524,7 @@ def codegen_template( kernel_name = self.define_kernel(src_code, node_schedule, kernel) self.codegen_comment(node_schedule) - - # debug printing values of intermediate tensors - _, call_args, arg_signatures, _ = kernel.args.python_argdefs() - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_signatures, kernel - ) - with debug_printer_manager: - kernel.call_kernel(kernel_name, template_node.node) + kernel.call_kernel(kernel_name, template_node.node) V.graph.removed_buffers |= kernel.removed_buffers V.graph.inplaced_to_remove |= kernel.inplaced_to_remove diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 65dd3608d2a841..c30f2d0bddc2f1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -385,6 +385,12 @@ def _print_TruncToInt(self, expr): f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) + def _print_Float(self, expr): + # Use a tensor here to get float64. Otherwise the constant is + # truncated to float32. + ret = f"tl.full([1], {expr}, tl.float64)" + return ret + def _print_ToFloat(self, expr): assert len(expr.args) == 1 return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" @@ -2496,7 +2502,7 @@ def codegen_kernel_benchmark(self, num_gb, grid=None): result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") @@ -2624,6 +2630,20 @@ def codegen_kernel(self, name=None): mutated_args.add(self.args.inplace_buffers[mutation].inner_name) if mutation in self.args.output_buffers: mutated_args.add(self.args.output_buffers[mutation]) + + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. + # + # In the logic below, we only mark the workspaces a mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. + for argname, arg in zip(argdefs, signature): + if isinstance(arg, WorkspaceArg) and arg.zero_fill: + mutated_args.add(argname) + mutated_args = sorted(mutated_args) triton_meta_signature = signature_to_meta( @@ -2814,7 +2834,7 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): call_args, grid, current_device.index, - cuda=True, + gpu=True, triton=True, arg_types=arg_types, grid_fn=self._get_grid_fn(), @@ -3046,7 +3066,9 @@ def define_kernel(self, src_code, node_schedule, kernel): return kernel_name def benchmark_fused_nodes(self, nodes): - with preserve_rng_state(): + with preserve_rng_state(), torch.cuda.device( + self.scheduler.get_current_device_or_throw() + ): src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True ) @@ -3110,9 +3132,10 @@ def store_cache(): # in the case of mutating/in-placeable second fusion # TODO - would be better as a hook in triton do_bench that reset # the input values between benchmarking - ms = ms - benchmarker.benchmark_gpu( - lambda: wrapped_jit_function.clone_args(*args) - ) + if len(wrapped_jit_function.mutated_arg_names) > 0: + ms = ms - benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args) + ) log.debug( "The fused kernel for %s took %.3f ms to run", diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 4d0f5165e1c56f..134c226d38effe 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -3,9 +3,20 @@ import textwrap from collections import defaultdict from dataclasses import dataclass -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union - -from sympy import Integer +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) + +from sympy import Integer, Symbol from torch.utils._ordered_set import OrderedSet @@ -16,7 +27,14 @@ from ..scheduler import BaseSchedulerNode from ..utils import Placeholder from ..virtualized import V -from .common import DeferredLine, IndentedBuffer, Kernel, PythonPrinter, SizeArg +from .common import ( + DeferredLine, + IndentedBuffer, + Kernel, + PythonPrinter, + SizeArg, + WorkspaceArg, +) from .simd import SIMDScheduling from .triton import gen_common_triton_imports, TritonKernel from .triton_utils import config_of, signature_to_meta @@ -71,7 +89,9 @@ def _default_custom_combo_kernel_horizontal_partition( not_reduction = [n for n in group_per_dim if n not in reduction] # rnumel > 2048 usually has long execution time # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes - long_reduction = [n for n in reduction if cast(Integer, n.group[-1][-1]) > 2048] + long_reduction = [ + n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 + ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: log.warning( @@ -286,15 +306,18 @@ def _calculate_xblocks( ) -> None: x_numels_list = kernel.x_numels_list for i in range(len(x_numels_list)): - xnumels = ( - x_numels_list[i] - if x_numels_list[i] > 0 - else kernel.min_x_blocks_list[i] + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) ) xblock_str = ( - f"tl.cdiv({xnumels}, XBLOCK)" - if x_numels_list[i] > 0 - else f"{xnumels}" + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" ) if i == 0: code.splice(f"num_xblocks_{i} = {xblock_str}") @@ -303,19 +326,30 @@ def _calculate_xblocks( @classmethod def grid( - cls, sub_kernel_numels: List[List[int]], x_blocks_list: List[int] + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, ) -> Tuple[Any, ...]: - xnumel = x_blocks_list - ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] - znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] + xnumel = list(x_blocks_list) + ynumel: Any = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] + znumel: Any = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] - # TODO: improve 1d/2d mixed cases - ynumel = ( - None if any(e is None for e in ynumel) else max(cast(List[int], ynumel)) - ) - znumel = ( - None if any(e is None for e in znumel) else max(cast(List[int], znumel)) - ) + if dynamic_shape: + ynumel = None if None in ynumel else ynumel + znumel = None if None in znumel else znumel + else: + # TODO: improve 1d/2d mixed cases + ynumel = ( + None + if any(e is None for e in cast(List[Any], ynumel)) + else max(cast(Iterable[int], ynumel)) + ) + znumel = ( + None + if any(e is None for e in cast(List[Any], znumel)) + else max(cast(Iterable[int], znumel)) + ) numels = ( (xnumel,) @@ -352,21 +386,38 @@ def codegen_pid_range( @classmethod def grid( - cls, sub_kernel_numels: List[List[int]], x_blocks_list: List[int] + cls, + sub_kernel_numels: List[List[int]], + x_blocks_list: List[Union[str, int]], + dynamic_shape: bool, ) -> Tuple[Any, ...]: - xnumel = [e[-1] if len(e) > 0 else None for e in sub_kernel_numels] + xnumel = x_blocks_list + # set no_x_dim xnumels to 0 + xnumel_x_dim = [max(e, 0) for e in xnumel] ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels] znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels] # TODO: support 1d/2d mixed cases xnumel = ( - None if any(e is None for e in xnumel) else max(cast(List[int], xnumel)) + None + if any(e is None for e in xnumel) + else xnumel + if dynamic_shape + else max(xnumel_x_dim) # type: ignore[type-var, arg-type] ) ynumel = ( - None if any(e is None for e in ynumel) else max(cast(List[int], ynumel)) + None + if any(e is None for e in ynumel) + else ynumel + if dynamic_shape + else max(ynumel) # type: ignore[type-var, arg-type] ) znumel = ( - None if any(e is None for e in znumel) else max(cast(List[int], znumel)) + None + if any(e is None for e in znumel) + else znumel + if dynamic_shape + else max(znumel) # type: ignore[type-var, arg-type] ) numels = ( @@ -385,8 +436,8 @@ def __init__( self.sub_kernels: List[TritonKernel] = [] self.iter_vars_count = itertools.count() self.grids: List[List[int]] = [] - self.min_x_blocks_list: List[int] = [] - self.x_numels_list: List[int] = [] + self.min_x_blocks_list: List[Union[int, str]] = [] + self.x_numels_list: List[Union[int, str]] = [] self.enable_autotune = enable_autotune self.mixed_sizes = mixed_sizes self.dispatch_class: Optional[ @@ -401,6 +452,7 @@ def __init__( self.block_size_2d = 32 self.num_warps = 8 self.block_size_reduce = 256 + self.dynamic_shape_args: List[str] = [] def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: sub_kernel = triton_kernel @@ -419,6 +471,10 @@ def create_triton_kernel( reduction_hint: ReductionHint, optimize_mask: bool, ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except x dimension are the same for each sub kernel. + """ return TritonKernel( *groups, index_dtype=index_dtype, @@ -451,16 +507,25 @@ def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexp uniquify_block_sizes = [] for tree in sub_kernel.range_trees: simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) - code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") if tree.prefix != "r": - grid.append(int(simplified_tree_numel)) + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + grid.append(f"{tree.prefix}numel_{num}") if tree.prefix == "r" and sub_kernel.persistent_reduction: if isinstance(simplified_tree_numel, (Integer, int)): val = int(simplified_tree_numel) else: - continue + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) val = next_power_of_2(val) code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}") uniquify_block_sizes.append("RBLOCK") @@ -476,16 +541,27 @@ def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. Grid calculation needs to make sure that they are assigned with enough number of blocks. """ - min_x_blocks = 0 - x_numels = 0 + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 for tree in sub_kernel.range_trees: simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" if sub_kernel.no_x_dim: - min_x_blocks = int(simplified_tree_numel) - x_numels = -min_x_blocks + min_x_blocks = x_numels + x_numels = ( + -min_x_blocks + if isinstance(x_numels, int) + else "-" + cast(str, x_numels) + ) else: - x_numels = int(simplified_tree_numel) + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" self.min_x_blocks_list.append(min_x_blocks) self.x_numels_list.append(x_numels) @@ -562,12 +638,15 @@ def get_mutated_args_sub_kernels(self) -> List[str]: def select_dispatch_strategy(self) -> None: if self.dispatch_class is not None: return - if not self.mixed_sizes: + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. + if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in min_x_blocks_list means a dynamic shape self.dispatch_class = ComboKernel.SequentialDispatch return # A negative x_blocks_list element means the kernel is not tunable, # i.e., no_x_dim = True - x_numels_list = [abs(e) for e in self.x_numels_list] + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] total = max(x_numels_list) * len(x_numels_list) needed = sum(x_numels_list) if needed / total > BLOCK_UTILIZATION: @@ -582,10 +661,12 @@ def jit_line( size_hints: List[int], selected_kernel: TritonKernel, pointwise_with_reduce: bool = False, + signature: Optional[List[Any]] = None, ) -> str: can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) size_dtype = "tl.int32" if can_use_32bit else "tl.int64" - _, _, signature, _ = self.args.python_argdefs() + if signature is None: + _, _, signature, _ = self.args.python_argdefs() for i, sub in enumerate(self.sub_kernels): self.min_x_blocks_sub_kernel(sub, i) self.select_dispatch_strategy() @@ -622,7 +703,7 @@ def jit_line( reduction_hint={reduction_hint}, filename=__file__, triton_meta={triton_meta!r}, - inductor_meta={inductor_meta} + inductor_meta={inductor_meta!r} ) @triton.jit """ @@ -678,6 +759,66 @@ def add_blockd_to_args(self, argdefs: List[str]) -> List[str]: self.block_args = list(block_names.keys()) return argdefs + def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(f"{tree.prefix}numel_{num}") + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args_and_grid( + self, name: str, call_args: List[Any], arg_types: List[Any], grid: List[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"numel args mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + + if tree.prefix != "r" or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def add_numel_to_call_args_and_grid_benchmark( + self, extra_args: List[Any], grid: Union[List[Any], Tuple[Any, ...]] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + expr = V.graph.sizevars.size_hint(tree.numel) + if tree.prefix != "r": + assert isinstance( + grid[i][num], str + ), f"Grid {grid[i][num]} should be a dynamic shape." + numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" + assert ( + grid[i][num] == numel_sign + numel_name + ), f"grid mismatch: {grid[i][num]} vs {numel_name}" + grid[i][num] = -expr if numel_sign == "-" else expr + if tree.prefix != "r" or sub_kernel.inside_reduction: + extra_args.append(expr) + def codegen_kernel(self, name: Optional[str] = None) -> str: # TODO: is it correct to use the first sub kernel's heuristics? heuristics_list, size_hints_list = [], [] @@ -699,7 +840,8 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: if config.benchmark_combo_kernel: code.splice(self.imports_for_benchmark_kernel()) - argdefs, _, _, _ = self.args.python_argdefs() + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) argdefs = self.add_blockd_to_args(argdefs) code.splice( self.jit_line( @@ -707,6 +849,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: size_hints, selected_kernel, pointwise_with_reduce=pointwise_with_reduction, + signature=signature, ) ) code.writeline( @@ -752,7 +895,7 @@ def codegen_kernel_benchmark( var_names = [] for arg_name, arg_sig in zip(call_args, signature): var_name = f"arg_{next(name_cnt)}" - buf = V.graph.get_buffer(arg_name) + buf = V.graph.try_get_buffer(arg_name) if buf: result.writeline( f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long @@ -772,6 +915,12 @@ def codegen_kernel_benchmark( if "seed_offset" in arg_sig.name: symval_hint = 0 result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.scheduler.get_current_device_or_throw() + nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + result.writeline( + f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + ) else: raise KeyError( f"Don't find the buffer or const tensor for {arg_name}" @@ -782,15 +931,31 @@ def codegen_kernel_benchmark( result.writelines(["\n", "\n", "def call(args):"]) if grid is None: assert self.dispatch_class is not None - grid_tuple = self.dispatch_class.grid(self.grids, self.x_numels_list) + dynamic_shape = self.dynamic_shape_args != [] + grid_tuple = self.dispatch_class.grid( + self.grids, self.x_numels_list, dynamic_shape + ) + extra_args_str = "" + extra_args: List[Any] = [] + if dynamic_shape: + self.add_numel_to_call_args_and_grid_benchmark(extra_args, grid_tuple) + # convert nested list to list of str + grid_tuple = tuple( + "[" + ", ".join(pexpr(item) for item in e) + ",]" + for e in grid_tuple + ) + extra_args_str = ", ".join(map(str, extra_args)) + ", " + min_blocks = None + else: + min_blocks = max(self.min_x_blocks_list) * len(self.sub_kernels) grid_str = ", ".join(pexpr(item) for item in grid_tuple) grid_extra_kwargs = ( f"num_kernels={len(self.sub_kernels)}, " - f"min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}, " + f"min_blocks={min_blocks}, " f"is_sequential={self.dispatch_class is self.SequentialDispatch}" ) grid_str = f"{grid_str}, {grid_extra_kwargs}" - grid_arg = f"grid=grid_combo_kernels({grid_str})" + grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})" else: grid_arg = f"grid={grid}" index = V.graph.scheduler.get_current_device_or_throw().index @@ -827,7 +992,7 @@ def codegen_kernel_benchmark( result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40, fast_flush=True)" + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") @@ -882,13 +1047,22 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: wrapper = V.graph.wrapper_code assert self.dispatch_class is not None - grid = self.dispatch_class.grid(self.grids, self.x_numels_list) + dynamic_shape = self.dynamic_shape_args != [] + grid = list( + self.dispatch_class.grid(self.grids, self.x_numels_list, dynamic_shape) + ) num_kernels = len(self.sub_kernels) - min_blocks = max(self.min_x_blocks_list) * num_kernels + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) is_sequential = self.dispatch_class is self.SequentialDispatch - if not self.enable_autotune: + if dynamic_shape: + self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) + # convert nested list to list of str + # grid = tuple("["+", ".join(pexpr(item) for item in e)+",]" for e in grid) + if not self.enable_autotune and not dynamic_shape: launch_grid = self.grid_no_autotune( - grid, num_kernels, min_blocks, is_sequential + grid, num_kernels, cast(int, min_blocks), is_sequential ) V.graph.wrapper_code.generate_kernel_call( name, @@ -906,34 +1080,33 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: num_kernels=num_kernels, min_blocks=min_blocks, is_sequential=is_sequential, + default_meta=None if self.enable_autotune else self.get_default_meta(), ) wrapper.generate_kernel_call( name, call_args, grid, V.graph.scheduler.get_current_device_or_throw().index, - cuda=True, + gpu=True, triton=True, arg_types=arg_types, grid_fn="grid_combo_kernels", grid_extra_kwargs=( f"num_kernels={num_kernels}, " f"min_blocks={min_blocks}, " - f"is_sequential={is_sequential}" + f"is_sequential={is_sequential}, " + f"default_meta={None if self.enable_autotune else self.get_default_meta()}" ), ) def grid_no_autotune( self, - grid: Tuple[Any], + grid: Union[Tuple[Any], List[Any]], num_kernels: int, min_blocks: int, is_sequential: bool, ) -> List[int]: - if "YBLOCK" in self.block_args: - meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d} - else: - meta = {"XBLOCK": self.block_size_1d} + meta = self.get_default_meta() grid_func = grid_combo_kernels( *grid, num_kernels=num_kernels, @@ -941,3 +1114,10 @@ def grid_no_autotune( is_sequential=is_sequential, ) return grid_func(meta) + + def get_default_meta(self) -> Dict[str, int]: + if "YBLOCK" in self.block_args: + meta = {"XBLOCK": self.block_size_2d, "YBLOCK": self.block_size_2d} + else: + meta = {"XBLOCK": self.block_size_1d} + return meta diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index fd260e9cb09964..fe95b92c6877e8 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -310,13 +310,13 @@ def codegen(self, code: IndentedBuffer) -> None: if self.last_seen_device_guard_index is None: if config.abi_compatible: code.writeline( - "AOTICudaStreamGuard stream_guard(stream, this->device_idx_);" + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" ) else: code.writeline( maybe_hipify_code_wrapper( - "at::cuda::CUDAStreamGuard stream_guard(" - + "at::cuda::getStreamFromExternal(stream, this->device_idx_));" + f"{V.graph.device_ops.cpp_stream_guard()} stream_guard(" + + f"{V.graph.device_ops.cpp_getStreamFromExternal()}(stream, this->device_idx_));" ) ) else: @@ -326,10 +326,10 @@ def codegen(self, code: IndentedBuffer) -> None: else: if self.last_seen_device_guard_index is None: code.writeline( - f"AOTICudaGuard device_guard({self.device_idx});" + f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" if config.abi_compatible else maybe_hipify_code_wrapper( - f"at::cuda::CUDAGuard device_guard({self.device_idx});" + f"{V.graph.device_ops.cpp_device_guard()} device_guard({self.device_idx});" ) ) else: @@ -534,7 +534,7 @@ def add_import_once(line: str) -> None: # intermediate tensor value printing utility self.debug_printer = DebugPrinterManager( - enable_debug_printer=config.aot_inductor.debug_intermediate_value_printer + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer ) def write_constant(self, name: str, hashed: str) -> None: @@ -545,6 +545,9 @@ def write_header(self) -> None: aot_config_comment = "" if context is not None and context.aot_graph_name is not None: aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + aot_inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" self.imports.splice( f""" {aot_config_comment} @@ -562,6 +565,7 @@ def write_header(self) -> None: from {async_compile.__name__} import AsyncCompile from torch._inductor.select_algorithm import extern_kernels from torch._inductor.codegen.multi_kernel import MultiKernelCall + {aot_inductor_debug_utils} """, strip=True, ) @@ -581,6 +585,9 @@ def write_header(self) -> None: strip=True, ) + def include_extra_header(self, header: str): + pass + def write_kernel_autotune_defs_header(self) -> None: self.kernel_autotune_defs.splice( f""" @@ -784,11 +791,21 @@ def generate_extern_kernel_alloc(self, extern_kernel, args): def generate_extern_kernel_out( self, kernel: str, out: str, out_view: Optional[str], args: List[str] ): + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") args.append(f"out={out_view if out_view else out}") - self.writeline(f"{kernel}({', '.join(args)})") + with debug_printer_manager: + self.writeline(f"{kernel}({', '.join(args)})") def generate_user_defined_triton_kernel( - self, kernel_name, raw_args, grid, configs, triton_meta, constexprs + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): grid_fn, code = user_defined_kernel_grid_fn_code( kernel_name, configs, grid, wrapper=self @@ -1243,7 +1260,7 @@ def add_benchmark_harness(self, output): ) def define_kernel( - self, name: str, kernel: str, metadata: Optional[str] = None, cuda=True + self, name: str, kernel: str, metadata: Optional[str] = None, gpu=True ): metadata_comment = f"{metadata}\n" if metadata else "" body = f"\n\n{metadata_comment}{name} = {kernel}" @@ -1461,8 +1478,10 @@ def traverse(cur_kernel): ) return name, triton_meta - def generate_numel_expr(self, kernel_name: str, tree): + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" if (expr, V.graph) not in self.kernel_numel_expr: # declare expr once in each graph (scope) self.kernel_numel_expr.add((expr, V.graph)) @@ -1541,9 +1560,9 @@ def generate_save_uncompiled_kernels(self): def generate_default_grid( self, - name: str, + kernel_name: str, grid: List[Any], - cuda: bool = True, + gpu: bool = True, grid_callable: Optional[Callable[..., Any]] = None, **grid_extra_kwags, ): @@ -1616,34 +1635,44 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): ) elif isinstance(arg, (str, int, float, bool)): return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" else: - breakpoint() raise NotImplementedError(f"Unsupported type {type(arg)}") + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + def generate_kernel_call( self, kernel_name, call_args, grid=None, device_index=None, - cuda=True, + gpu=True, triton=True, arg_types=None, raw_args=None, grid_fn: str = "grid", triton_meta=None, + autotune_configs=None, grid_extra_kwargs="", ): """ Generates kernel call code. - cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + gpu: Defines whether the backend is GPU. Otherwise the backend is CPU. triton: Defines whether the GPU backend uses Triton for codegen. Otherwise it uses the CUDA language for codegen. - Only valid when cuda == True. + Only valid when gpu == True. """ - if cuda: + if gpu: device_index, call_args_str = self.prepare_triton_kernel_call( device_index, call_args ) @@ -1654,13 +1683,19 @@ def generate_kernel_call( if grid is None: grid_str = grid_fn else: - grid_str = ", ".join(pexpr(item) for item in grid) + grid_str = ", ".join(self._grid_dim_str(item) for item in grid) if grid_extra_kwargs: grid_str = f"{grid_str}, {grid_extra_kwargs}" grid_str = f"{grid_fn}({grid_str})" - self.writeline( - f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_types, None ) + with debug_printer_manager: + self.writeline( + f"{kernel_name}.run({call_args_str}, grid={grid_str}, stream={stream_name})" + ) if ( config.triton.autotune_at_compile_time and kernel_name not in self.kernel_autotune_names @@ -1708,6 +1743,8 @@ def generate_kernel_call( grid_str = ", ".join( self.generate_example_arg_value(g, type(g)) for g in grid ) + if grid_extra_kwargs: + grid_str = f"{grid_str}, {grid_extra_kwargs}" grid_str = f"{grid_fn}({grid_str})" self.kernel_autotune_calls.writeline( @@ -2007,6 +2044,8 @@ def statically_known_int_or_none(x): # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but # the actual codegen will still generate the full expression here. return None + if isinstance(x, int): + return x val = V.graph._shape_env._maybe_evaluate_static(x) return int(val) except Exception: diff --git a/torch/_inductor/codegen/xpu/device_op_overrides.py b/torch/_inductor/codegen/xpu/device_op_overrides.py index 6eec71344ae8cc..a0c9a952991ecb 100644 --- a/torch/_inductor/codegen/xpu/device_op_overrides.py +++ b/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -15,5 +15,20 @@ def synchronize(self): def device_guard(self, device_idx): return f"torch.xpu._DeviceGuard({device_idx})" + def cpp_device_guard(self): + return "at::xpu::XPUGuard" + + def cpp_aoti_device_guard(self): + return "AOTIXpuGuard" + + def cpp_stream_guard(self): + return "at::xpu::XPUStreamGuard" + + def cpp_aoti_stream_guard(self): + return "AOTIXpuStreamGuard" + + def cpp_getStreamFromExternal(self): + return "at::xpu::getStreamFromExternal" + register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index dcad1e1bf67c98..7851e154d7386c 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -3,12 +3,14 @@ from __future__ import annotations import heapq +import logging import operator import sys from collections import defaultdict from typing import Dict, List, Set, TYPE_CHECKING import torch +from torch.multiprocessing.reductions import StorageWeakRef from . import config, ir from .dependencies import WeakDep @@ -23,6 +25,7 @@ ) +log = logging.getLogger(__name__) overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") if TYPE_CHECKING: @@ -342,6 +345,202 @@ def reorder_compute_and_comm_for_overlap( return order +def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): + """ + This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding + graph intermediates that were fsdp.copy_ into the unsharded params in the original graph. + + NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern + (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case + where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't + remove these resize and copy ops and thus we will have worse performance there. + + In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param" + is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern + (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed. + """ + node_list = list(graph.nodes) + + # Find all graph inputs and their resize counts + graph_input_to_resized_to_full_node_idxes = defaultdict(list) + graph_input_to_resized_to_0_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + ): + assert ( + node.args[0].op == "placeholder" + ), f"""\ +Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} +""" + graph_input = node.args[0] + new_size = node.args[1] + if new_size > 0: + graph_input_to_resized_to_full_node_idxes[graph_input].append(idx) + else: + graph_input_to_resized_to_0_node_idxes[graph_input].append(idx) + + def check_resize_pattern(graph_input): + # Check the number of resize-to-full and resize-to-0 nodes are equal, + # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node + # always happens before the resize-to-0 node. + # This is the precondition for being able to remove all the resize and copy nodes + # for this specific unsharded param. + resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get( + graph_input, [] + ) + resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, []) + + if not len(resized_to_full_idxes) == len(resized_to_0_idxes): + log.warning( + f""" +Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}: +{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass. +""" # noqa: G004 + ) + return False + + # Check the sequence: (resize_to_full -> resize_to_0)+ + for resize_to_full_idx, resize_to_0_idx in zip( + resized_to_full_idxes, resized_to_0_idxes + ): + if resize_to_full_idx >= resize_to_0_idx: + log.warning( + f""" +For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx} +happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param. +""" # noqa: G004 + ) + return False + return True + + # Find all eligible unsharded params and their corresponding graph intermediates. + unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: + fsdp_copy_node = node + unsharded_param = node.args[0] + assert ( + unsharded_param.op == "placeholder" + ), f""" +Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! +Offending node: {unsharded_param}. Graph: {graph} +""" + if check_resize_pattern(unsharded_param): + unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx) + + def is_allowed_mutation(node): + return ( + node.target == torch.ops.fsdp.copy_.default + or node.target == torch.ops.inductor.resize_storage_bytes_.default + ) + + def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): + # Check whether the node is mutating any of the unsharded params or their aliases. + mutated_arg_idxes = ( + [ + i + for i, x in enumerate(node.target._schema.arguments) + if x.alias_info is not None and x.alias_info.is_write + ] + if isinstance(node.target, torch._ops.OpOverload) + else [] + ) + mutated_node_arg_storages = { + StorageWeakRef(node.args[i].meta["val"].untyped_storage()) + for i in mutated_arg_idxes + } + storages_of_unsharded_params = { + StorageWeakRef(unsharded_param.meta["val"].untyped_storage()) + for unsharded_param in unsharded_params + } + return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0 + + # Check no user mutation on any unsharded_param + for node in node_list: + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and node.target._schema.is_mutable + and not is_allowed_mutation(node) + ): + assert not is_node_mutating_unsharded_param_or_its_alias( + node, unsharded_param_to_fsdp_copy_node_idxes.keys() + ), f"""\ +User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node} +""" + + # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`. + # + # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input. + # e.g. + # ``` + # fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1) + # ... (use of unsharded_param_1) -> Subgraph 1 + # fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2) + # ... (use of unsharded_param_1) -> Subgraph 2 + # fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3) + # ... (use of unsharded_param_1) -> Subgraph 3 + # ``` + # We must do the replacement only within each subgraph. + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + assert fsdp_copy_node.args[0] is unsharded_param + _, replacement = fsdp_copy_node.args + # subgraph_start_idx is exclusive + subgraph_start_idx = fsdp_copy_node_idx + 1 + # subgraph_end_idx is exclusive (also intentionally don't replace args in return op) + subgraph_end_idx = ( + fsdp_copy_node_idxes[i + 1] + if i < len(fsdp_copy_node_idxes) - 1 + else len(node_list) - 1 + ) + subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx] + assert not any( + is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param]) + for node in subgraph_nodes + ), f"""\ +Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true! +Graph: {graph} +""" + for node in subgraph_nodes: + if ( + node.op == "call_function" + and unsharded_param in node.args + and node.target != torch.ops.inductor.resize_storage_bytes_.default + ): # TODO(yf225): implement replacement in kwargs + new_args = tuple( + replacement if arg is unsharded_param else arg + for arg in node.args + ) + node.args = new_args + + # Delete `fsdp.copy_(unsharded_param, Y)` nodes + for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + graph.erase_node(fsdp_copy_node) + + # Delete `resize_(unsharded_param, ...)` nodes + for node in node_list: + if ( + node.op == "call_function" + and node.target == torch.ops.inductor.resize_storage_bytes_.default + and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes + ): + graph.erase_node(node) + + def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: try: import torch.distributed._composable.fsdp._fsdp_collectives @@ -509,12 +708,11 @@ def _create_group_node(snodes_to_group): name_to_fused_node, ) - # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block + # Find the "all_gather + all_gather_wait_tensor + copy_out" code block allowed_ops = { torch.ops._c10d_functional.all_gather_into_tensor_out.default, torch.ops._c10d_functional.wait_tensor.default, torch.ops.fsdp.split_with_sizes_copy.default, - torch.ops.aten.set_.source_Tensor, } find_recursive_users_of_node( ag_snode, @@ -560,7 +758,7 @@ def _create_group_node(snodes_to_group): assert wait_node_idx is not None ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) - # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode + # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index fd8d6626deeb95..763f7cab3e2cfa 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -17,6 +17,7 @@ import torch.fx import torch.utils._pytree as pytree from functorch.compile import min_cut_rematerialization_partition +from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import ( compiled_autograd, config as dynamo_config, @@ -62,6 +63,7 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.monitor import _WaitCounter +from torch.utils._ordered_set import OrderedSet from .._dynamo.backends.common import aot_autograd from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] @@ -382,7 +384,7 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) ) - if config.comprehensive_padding and not has_gpu: + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: perf_hint_log.info("Skip comprehensive padding on CPU") return config.patch(comprehensive_padding=False) else: @@ -399,20 +401,22 @@ def fake_tensor_prop( The created fake mode will be returned. """ - fake_mode = detect_fake_mode(example_inputs) - if not fake_mode: - fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) - FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) - else: - ctx = ( - contextlib.nullcontext() - if not force_allow_non_fake_inputs - else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) - ) - with ctx: # type: ignore[attr-defined] - FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( - *example_inputs + # Ensure that decomps that support symbolic shapes are used + with enable_python_dispatcher(): + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) return fake_mode @@ -1037,7 +1041,9 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] - static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + static_input_idxs: OrderedSet[int] = OrderedSet( + remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + ) copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] assert isinstance(inputs, list) @@ -1129,6 +1135,7 @@ def compile_fx_aot( if config_patches is None else {**config_patches, "cpp_wrapper": True} ) + if ( "aot_inductor.output_path" not in config_patches and not config.aot_inductor.output_path @@ -1239,6 +1246,18 @@ def wrapper(args): return wrapper +def get_cpp_wrapper_config(): + return { + # Set autotune_at_compile_time to True as default if the option is not explicitly set + "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + else True, + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, # TODO: to be removed + "triton.store_cubin": True, + } + + def compile_fx( model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor], @@ -1261,18 +1280,8 @@ def compile_fx( if config.cpp_wrapper: with config.patch( { - "cpp_wrapper": False, - # For triton.autotune_at_compile_time, disable by default for - # FBCode, but enabled by default for OSS. - "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time - if config.is_fbcode() - else os.environ.get( - "TORCHINDUCTOR_TRITON_AUTOTUNE_AT_COMPILE_TIME", "1" - ) - == "1", - "triton.autotune_cublasLt": False, - "triton.cudagraphs": False, - "triton.store_cubin": True, + "cpp_wrapper": False, # reset to break recursive call to compile_fx + **get_cpp_wrapper_config(), } ), V.set_real_inputs(example_inputs_): inputs_ = example_inputs_ @@ -1463,16 +1472,19 @@ def bw_compiler( n.name for n in model_outputs if isinstance(n, torch.fx.Node) ) fixed = count_tangents(model) - return inner_compile( - model, - example_inputs, - static_input_idxs=list(range(fixed)), - cudagraphs=cudagraphs, - is_backward=True, - graph_id=graph_id, - boxed_forward_device_index=forward_device, - user_visible_outputs=user_visible_outputs, - ) + with config.patch( + get_cpp_wrapper_config() + ) if config.cpp_wrapper else contextlib.nullcontext(): + return inner_compile( + model, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=cudagraphs, + is_backward=True, + graph_id=graph_id, + boxed_forward_device_index=forward_device, + user_visible_outputs=user_visible_outputs, + ) # TODO: can add logging before/after the call to create_aot_dispatcher_function # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simplified::new_func diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d8af1934eea8a1..dd35755f966311 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -25,6 +25,11 @@ def autotune_remote_cache_default() -> Optional[bool]: return None +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1" +) + # add some debug printouts debug = False @@ -60,6 +65,17 @@ def autotune_remote_cache_default() -> Optional[bool]: # sleep in inductor for testing sleep_sec_TESTING_ONLY: Optional[int] = None +# The default layout constraint for custom operators. +# This must be the name of one of the layout constraint tags +# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}), +# If the custom op does not have a layout constraint tag already +# then we assume the following applies. +custom_op_default_layout_constraint = "needs_fixed_stride_order" + +# The default layout constraint for user-defined triton kernels. +# See "The default layout constraint for custom operators" for options. +triton_kernel_default_layout_constraint = "flexible_layout" + # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" @@ -182,14 +198,7 @@ def autotune_remote_cache_default() -> Optional[bool]: # merge_splits_pass # mutate_cat_pass # split_cat_pass -pre_grad_fusion_options: Dict[str, Dict[str, Any]] = { - "batch_linear": {}, - "batch_linear_lhs": {}, - "batch_layernorm": {}, - "batch_tanh": {}, - "batch_relu": {}, - "batch_sigmoid": {}, -} +pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {} # Post grad fusion and options, set to empty dict to disable fusion. # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. @@ -314,8 +323,8 @@ def autotune_remote_cache_default() -> Optional[bool]: # that can appear in the input shapes (e.g., in autotuning) unbacked_symint_fallback = 8192 -# DEPRECATED, DO NOT USE -search_autotune_cache = False +# enable searching global and local cache regardless of `max_autotune` +search_autotune_cache = os.environ.get("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1" save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" @@ -411,6 +420,9 @@ def use_autoheuristic(name: str) -> bool: debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion = ( + os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" +) # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel benchmark_epilogue_fusion = ( @@ -465,6 +477,8 @@ def use_autoheuristic(name: str) -> bool: # Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable # for all except for foreach, 2 - enable for all combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = False # constant folding on the joint graph joint_graph_constant_folding = True @@ -497,9 +511,18 @@ def use_autoheuristic(name: str) -> bool: # The multiprocessing start method to use for inductor workers in the codecache. # Can be "subprocess" or "fork". def decide_worker_start_method() -> str: - start_method = os.environ.get( - "TORCHINDUCTOR_WORKER_START", "fork" if is_fbcode() else "subprocess" - ) + # TODO: For internal rollout, we use a killswitch to disable the "subprocess" + # start method. The justknob check should not be performed at import, however, + # so for fbcode, we assign worker_start_method to None below and call this method + # lazily in async_compile.py. Remove this after "subprocess" rollout completes. + if "TORCHINDUCTOR_WORKER_START" in os.environ: + start_method = os.environ["TORCHINDUCTOR_WORKER_START"] + elif is_fbcode() and not torch._utils_internal.justknobs_check( + "pytorch/inductor:subprocess_parallel_compile" + ): + start_method = "fork" + else: + start_method = "subprocess" assert start_method in ( "subprocess", "fork", @@ -507,7 +530,10 @@ def decide_worker_start_method() -> str: return start_method -worker_start_method = decide_worker_start_method() +# TODO: Set start method directly after internal rollout of "subprocess". +worker_start_method: Optional[str] = ( + None if is_fbcode() else decide_worker_start_method() +) # Flags to turn on all_reduce fusion. These 2 flags should be automaticaly turned # on by DDP and should not be set by the users. @@ -563,17 +589,18 @@ def decide_compile_threads() -> int: # gemm autotuning global cache dir if is_fbcode(): - from libfb.py import parutil - try: + from libfb.py import parutil + if __package__: global_cache_dir = parutil.get_dir_path( os.path.join(__package__.replace(".", os.sep), "fb/cache") ) else: global_cache_dir = parutil.get_dir_path("fb/cache") - except ValueError: + except (ValueError, ImportError): global_cache_dir = None + else: global_cache_dir = None @@ -590,6 +617,37 @@ def decide_compile_threads() -> int: ) pad_channels_last = False +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignement=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + # Whether to treat output of the backward graph as user visible. # For user visible outputs, inductor will make sure the stride matches with eager. bw_outputs_user_visible = True @@ -671,6 +729,12 @@ def decide_compile_threads() -> int: # set unless you know what you're doing. unsafe_ignore_unsupported_triton_autotune_args: bool = False +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. +check_stack_no_cycles_TESTING_ONLY: bool = False + # config specific to codegen/cpp.py class cpp: @@ -768,6 +832,9 @@ class cpp: # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + # config specific to codegen/triton.py class triton: @@ -838,7 +905,8 @@ class triton: autotune_cublasLt = True # Tune the generated Triton kernels at compile time instead of first time they run - autotune_at_compile_time = False + # Setting to None means uninitialized + autotune_at_compile_time: Optional[bool] = None # should we stop a fusion to allow better tiling? tiling_prevents_pointwise_fusion = True @@ -918,21 +986,26 @@ class aot_inductor: os.environ.get("AOT_INDUCTOR_DEBUG_DUMP_CONSTS_BIN", "0") == "1" ) - # enable debug mode for aot inductor and it will print out more information including the intermediate tensor values, etc - # for debugging purpose - debug_intermediate_value_printer = ( - os.environ.get("AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0") == "1" + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + # 3: enable printing kernel names only (useful for pinpointing troublesome kernels) + debug_intermediate_value_printer = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" ) - # filtered nodes to be printed for debug values. If not set, it will dump all debug tensor value info by default + # filtered nodes to be printed for debug values. Specify this option when debug_intermediate_value_printer is set to 2 filtered_kernel_names = os.environ.get( - "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", "default" + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None ) # Serialized tree spec for flattening inputs + # TODO: Move this into metadata serialized_in_spec = "" # Serialized tree spec for flattening outputs + # TODO: Move this into metadata serialized_out_spec = "" # flag to decide whether to create a submodule for constant graph. @@ -943,6 +1016,11 @@ class aot_inductor: force_mmap_weights: bool = False package: bool = False + package_cpp_only: bool = False + + # Dictionary of metadata users might want to save to pass to the runtime. + # TODO: Move this somewhere else, since it's no longer really a config + metadata: Dict[str, str] = {} class cuda: @@ -1168,6 +1246,7 @@ class trace: # uses absolute path "cuda.cutlass_dir", # not relevant + "worker_start_method", "compile_threads", ] diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 72f34d32475f39..09abe579b52045 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -86,15 +86,34 @@ def _deduce_value(self, node: torch.fx.Node) -> Any: return super().run_node(node) def is_impure(self, node: torch.fx.node.Node) -> bool: + def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: + return ( + node.target == torch.ops.prims.convert_element_type.default # type: ignore[return-value] + and isinstance(node.args[0], torch.fx.Node) + and "val" in node.args[0].meta + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ) + if ( - node.target == torch.ops.prims.convert_element_type.default - and is_const_source(node.args[0], self.lifted_constants) # type: ignore[arg-type] - and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] - and node.args[1] == torch.bfloat16 + is_woq_int8_pattern(node) + or ( + node.target == torch.ops.aten.permute.default + and len(node.users) == 1 + and is_woq_int8_pattern(next(iter(node.users))) + ) + ) and is_const_source( + node.args[0], self.lifted_constants # type: ignore[arg-type] ): - # For int8_weight -> dq -> bf16_weight + # Case 1: int8_weight -> dq -> bf16_weight + # Case 2: int8_weight -> permute -> dq -> bf16_weight return True - if node.target in [ + + quant_registered = ( + getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) + is not None + ) + if quant_registered and node.target in [ torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 3ade5cf558342e..e6bc67b5289b92 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -15,6 +15,7 @@ import sys import sysconfig import warnings +from ctypes import cdll from pathlib import Path from typing import Any, List, Optional, Sequence, Tuple, Union @@ -23,6 +24,7 @@ from torch._inductor import config, exc from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion if config.is_fbcode(): @@ -57,7 +59,7 @@ def use_global_cache() -> bool: _IS_MACOS = sys.platform.startswith("darwin") _IS_WINDOWS = sys.platform == "win32" -SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else () +SUBPROCESS_DECODE_ARGS = ("utf-8",) if _IS_WINDOWS else () log = logging.getLogger(__name__) @@ -131,6 +133,9 @@ def check_compiler_exist_windows(compiler: str) -> None: ) except FileNotFoundError as exc: raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass def get_cpp_compiler() -> str: @@ -139,7 +144,9 @@ def get_cpp_compiler() -> str: check_compiler_exist_windows(compiler) else: if config.is_fbcode(): - return build_paths.cc() + return ( + build_paths.cc() if torch.version.hip is None else build_paths.clang() + ) if isinstance(config.cpp.cxx, (list, tuple)): search = tuple(config.cpp.cxx) else: @@ -158,6 +165,13 @@ def _is_clang(cpp_compiler: str) -> bool: # Mac OS apple clang maybe named as gcc, need check compiler info. if sys.platform == "darwin": return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) @@ -185,6 +199,50 @@ def _is_msvc_cl(cpp_compiler: str) -> bool: return False +@functools.lru_cache(None) +def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + + return is_intel_compiler + except FileNotFoundError as exc: + return False + except subprocess.SubprocessError: + # --version args not support. + return False + + return False + + @functools.lru_cache(None) def is_gcc() -> bool: return _is_gcc(get_cpp_compiler()) @@ -195,6 +253,11 @@ def is_clang() -> bool: return _is_clang(get_cpp_compiler()) +@functools.lru_cache(None) +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + @functools.lru_cache(None) def is_apple_clang() -> bool: return _is_apple_clang(get_cpp_compiler()) @@ -622,10 +685,20 @@ def _setup_standard_sys_libs( if config.is_fbcode(): cflags.append("nostdinc") - include_dirs.append(build_paths.sleef()) - include_dirs.append(build_paths.cc_include()) - include_dirs.append(build_paths.libgcc()) - include_dirs.append(build_paths.libgcc_arch()) + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + if torch.version.hip is None: + include_dirs.append(build_paths.sleef()) + include_dirs.append(build_paths.openmp()) + include_dirs.append(build_paths.python()) + if torch.version.hip is not None: + include_dirs.append(build_paths.clang_include()) + include_dirs.append(build_paths.gcc_include()) + include_dirs.append(build_paths.gcc_install_tools_include()) + else: + include_dirs.append(build_paths.cc_include()) + include_dirs.append(build_paths.libgcc()) + include_dirs.append(build_paths.libgcc_arch()) include_dirs.append(build_paths.libgcc_backward()) include_dirs.append(build_paths.glibc()) include_dirs.append(build_paths.linux_kernel()) @@ -761,6 +834,51 @@ def homebrew_libomp() -> Tuple[bool, str]: return False, "" +@functools.lru_cache(None) +def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + +@functools.lru_cache(None) +def perload_icx_libomp_win(cpp_compiler: str) -> None: + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + omp_module = cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implenmented more math libraries than clang, for performance proposal. + We need preload them like openmp library. + """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) + + def _get_openmp_args( cpp_compiler: str, ) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: @@ -817,11 +935,34 @@ def _get_openmp_args( # if openmp is still not available, we let the compiler to have a try, # and raise error together with instructions at compilation error later elif _IS_WINDOWS: - # /openmp, /openmp:llvm - # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ - # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 - cflags.append("openmp") - cflags.append("openmp:experimental") # MSVC CL + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. + """ + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL else: if config.is_fbcode(): include_dir_paths.append(build_paths.openmp()) @@ -836,6 +977,8 @@ def _get_openmp_args( # TODO: fix issue, can't find omp.h cflags.append("fopenmp") libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") else: cflags.append("fopenmp") libs.append("gomp") @@ -1023,8 +1166,8 @@ def _transform_cuda_paths(lpaths: List[str]) -> None: break -def get_cpp_torch_cuda_options( - cuda: bool, +def get_cpp_torch_device_options( + device_type: str, aot_mode: bool = False, compile_only: bool = False, ) -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: @@ -1040,15 +1183,17 @@ def get_cpp_torch_cuda_options( and "CUDA_HOME" not in os.environ and "CUDA_PATH" not in os.environ ): - os.environ["CUDA_HOME"] = build_paths.cuda() + os.environ["CUDA_HOME"] = ( + build_paths.rocm() if torch.version.hip else build_paths.cuda() + ) _set_gpu_runtime_env() from torch.utils import cpp_extension - include_dirs = cpp_extension.include_paths(cuda) - libraries_dirs = cpp_extension.library_paths(cuda) + include_dirs = cpp_extension.include_paths(device_type) + libraries_dirs = cpp_extension.library_paths(device_type) - if cuda: + if device_type == "cuda": definations.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") if torch.version.hip is not None: @@ -1056,15 +1201,17 @@ def get_cpp_torch_cuda_options( libraries += ["amdhip64"] else: libraries += ["c10_hip", "torch_hip"] - definations.append(" __HIP_PLATFORM_AMD__") + definations.append(" __HIP_PLATFORM_AMD__") else: if config.is_fbcode(): libraries += ["cuda"] else: - if config.is_fbcode(): - libraries += ["cuda"] - else: - libraries += ["c10_cuda", "cuda", "torch_cuda"] + libraries += ["c10_cuda", "cuda", "torch_cuda"] + + if device_type == "xpu": + definations.append(" USE_XPU") + cflags += ["fsycl"] + libraries += ["c10_xpu", "sycl", "ze_loader", "torch_xpu"] if aot_mode: if config.is_fbcode(): @@ -1073,7 +1220,7 @@ def get_cpp_torch_cuda_options( cpp_prefix_include_dir = [f"{os.path.dirname(cpp_prefix_path())}"] include_dirs += cpp_prefix_include_dir - if cuda and torch.version.hip is None: + if device_type == "cuda" and torch.version.hip is None: _transform_cuda_paths(libraries_dirs) if config.is_fbcode(): @@ -1082,7 +1229,7 @@ def get_cpp_torch_cuda_options( else: include_dirs.append(os.path.join(build_paths.cuda(), "include")) - if aot_mode and cuda: + if aot_mode and device_type == "cuda": if torch.version.hip is None: if not compile_only: # Only add link args, when compile_only is false. @@ -1099,18 +1246,18 @@ def get_cpp_torch_cuda_options( ) -class CppTorchCudaOptions(CppTorchOptions): +class CppTorchDeviceOptions(CppTorchOptions): """ This class is inherited from CppTorchOptions, which automatic contains base cxx build options and torch common build options. And then it will - maintains cuda device related build args. + maintains cuda/xpu device related build args. """ def __init__( self, vec_isa: VecISA = invalid_vec_isa, include_pytorch: bool = False, - cuda: bool = True, + device_type: str = "cuda", aot_mode: bool = False, compile_only: bool = False, use_absolute_path: bool = False, @@ -1127,33 +1274,37 @@ def __init__( use_mmap_weights=use_mmap_weights, extra_flags=extra_flags, ) + if device_type == "xpu": + from torch.utils.cpp_extension import _join_sycl_home + + self._compiler = _join_sycl_home("bin", "icpx") - cuda_definations: List[str] = [] - cuda_include_dirs: List[str] = [] - cuda_cflags: List[str] = [] - cuda_ldflags: List[str] = [] - cuda_libraries_dirs: List[str] = [] - cuda_libraries: List[str] = [] - cuda_passthough_args: List[str] = [] + device_definations: List[str] = [] + device_include_dirs: List[str] = [] + device_cflags: List[str] = [] + device_ldflags: List[str] = [] + device_libraries_dirs: List[str] = [] + device_libraries: List[str] = [] + device_passthough_args: List[str] = [] ( - cuda_definations, - cuda_include_dirs, - cuda_cflags, - cuda_ldflags, - cuda_libraries_dirs, - cuda_libraries, - cuda_passthough_args, - ) = get_cpp_torch_cuda_options( - cuda=cuda, aot_mode=aot_mode, compile_only=compile_only + device_definations, + device_include_dirs, + device_cflags, + device_ldflags, + device_libraries_dirs, + device_libraries, + device_passthough_args, + ) = get_cpp_torch_device_options( + device_type=device_type, aot_mode=aot_mode, compile_only=compile_only ) - _append_list(self._definations, cuda_definations) - _append_list(self._include_dirs, cuda_include_dirs) - _append_list(self._cflags, cuda_cflags) - _append_list(self._ldflags, cuda_ldflags) - _append_list(self._libraries_dirs, cuda_libraries_dirs) - _append_list(self._libraries, cuda_libraries) - _append_list(self._passthough_args, cuda_passthough_args) + _append_list(self._definations, device_definations) + _append_list(self._include_dirs, device_include_dirs) + _append_list(self._cflags, device_cflags) + _append_list(self._ldflags, device_ldflags) + _append_list(self._libraries_dirs, device_libraries_dirs) + _append_list(self._libraries, device_libraries) + _append_list(self._passthough_args, device_passthough_args) self._finalize_options() diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 7cadacfdf1c47b..bc4838e5f16855 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -296,9 +296,9 @@ def _check_and_append_supported_isa( if Arch != "x86_64" and Arch != "AMD64": return supported_isa - avx2 = torch.cpu._is_cpu_support_avx2() - avx512 = torch.cpu._is_cpu_support_avx512() - amx_tile = torch.cpu._is_cpu_support_amx_tile() + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + amx_tile = torch.cpu._is_amx_tile_supported() _check_and_append_supported_isa(supported_isa, avx2, "avx2") _check_and_append_supported_isa(supported_isa, avx512, "avx512") diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 422be37e78a0d2..5a33de0e366899 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -1578,7 +1578,7 @@ def _allocate_and_copy_recording_inputs( self, inputs: List[InputType] ) -> List[Union[torch.Tensor, int]]: """ - Allocate inputs for non static, non cudagraph managraphed managed tensors in the memory pool + Allocate inputs for non static, non cudagraph managed tensors in the memory pool and copy over the tensor values. """ diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index df26bba94190d0..868833a425be4b 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -111,6 +111,7 @@ def func1(*args: Any) -> int: FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) buf_to_fx_node = {} + node_to_fx_node = {} graph = torch.fx.Graph() first_node = None @@ -162,10 +163,9 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) - if isinstance(snode, FusedSchedulerNode): - for x in snode.snodes: - buf_to_fx_node[x.get_name()] = fx_node - buf_to_fx_node[name] = fx_node + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node if first_node is None: first_node = fx_node @@ -175,7 +175,7 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: name = snode.get_name() deps = snode.read_writes.reads - fx_node = buf_to_fx_node[name] + fx_node = node_to_fx_node[name] new_args = [] for dep in deps: if dep.name in buf_to_fx_node: @@ -184,6 +184,8 @@ def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: with graph.inserting_before(first_node): dep_node = graph.placeholder(dep.name) buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue new_args.append(dep_node) fx_node.args = tuple(new_args) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index d369c2a664459a..0e067395f80712 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -48,6 +48,7 @@ inductor_decompositions = get_decompositions( [ aten._adaptive_avg_pool2d_backward, + aten.addmv, aten.arange, aten.bitwise_and_, aten.bitwise_or_, @@ -650,10 +651,10 @@ def wrapped_quantized_linear( out_zero_point: torch.Tensor, out_channel: int, ) -> torch.Tensor: - packed_weight = torch.ops._quantized.wrapped_linear_prepack( + packed_weight = torch.ops._quantized._wrapped_linear_prepack( weight, weight_scale, weight_zero_point, bias ) - return torch.ops._quantized.wrapped_quantized_linear_prepacked( + return torch.ops._quantized._wrapped_quantized_linear_prepacked( input, input_scale, input_zero_point, diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 5fab371d4406de..643bf686bd8f19 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -1,12 +1,11 @@ # mypy: allow-untyped-defs import abc -import collections import dataclasses import itertools import logging import re import typing -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from unittest.mock import patch import sympy @@ -55,6 +54,9 @@ def has_unbacked_symbols(self) -> bool: def is_contiguous(self) -> bool: pass + def normalize_with_stride_order(self, prefix="t"): + return self + @dataclasses.dataclass(frozen=True) class MemoryDep(Dep): @@ -67,12 +69,87 @@ class MemoryDep(Dep): def __repr__(self) -> str: return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}, {self.mode})" + @property + def num_vars(self): + return len(self.var_names) + + def decide_loop_order_to_match(self, other): + """ + Can return None if not able to decide loop orders. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity and it may not matter that much. + # + # For size == 1, it cause cause tie for strides of different dimensions. + # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder + # we can dependencies.index_vars_squeeze which should already sqeeuze + # the size == 1 dimensions. + if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expression + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(set(self_strides)) != len(self_strides) or len( + set(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May hanppen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if set(self_strides) != set(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [] + for s in other_strides: + order.append(stride_to_index[s]) + + assert set(order) == set(range(0, self.num_vars)) + return order + def get_offset(self): """ Return the offset by setting every variable to be 0. """ return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The different to normalize_with_stride_order is, + this method does not reorder loops while normalize_with_stride_order reorder + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + def normalize_with_stride_order(self, prefix="t"): r""" Used to decide if two MemoryDep does not equal due to different loop orders. @@ -278,9 +355,6 @@ class ReadWrites: index_exprs: OrderedSet[IndexExprDep] range_vars: Optional[List[sympy.Expr]] = None var_ranges: Optional[VarRanges] = None - op_counts: typing.Counter[str] = dataclasses.field( - default_factory=collections.Counter - ) def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": return ReadWrites( @@ -289,39 +363,32 @@ def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) - def with_read(self, dep: Dep) -> "ReadWrites": - assert isinstance(dep, (WeakDep, StarDep)) + def with_read(self, dep: Union[Dep, Set[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, set)) + if not isinstance(dep, set): + dep = {dep} return ReadWrites( - OrderedSet.union(self.reads, [dep]), + OrderedSet.union(self.reads, dep), self.writes, self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) def merge(self, other: "ReadWrites"): reads = OrderedSet.union(self.reads, other.reads) writes = OrderedSet.union(self.writes, other.writes) index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) - op_counts = collections.Counter(self.op_counts) - op_counts.update(other.op_counts) - return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts) + return ReadWrites(reads - writes, writes, index_exprs) @staticmethod def merge_list(read_writes: List["ReadWrites"]): all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) - - op_counts: typing.Counter[Any] = collections.Counter() - for rw in read_writes: - op_counts.update(rw.op_counts) - - return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts) + return ReadWrites(all_reads, all_writes, all_index_exprs) def remove_reads(self, rem_reads): return ReadWrites( @@ -330,7 +397,6 @@ def remove_reads(self, rem_reads): self.index_exprs, self.range_vars, self.var_ranges, - op_counts=self.op_counts, ) def reads_and_writes(self): @@ -358,31 +424,31 @@ def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: self._writes: OrderedSet[MemoryDep] = OrderedSet() self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() self._var_ranges: VarRanges = var_ranges - self._normalize: bool = normalize + self._should_normalize: bool = normalize - def canonicalize( - self, index: sympy.Expr - ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: - if not self._normalize: - sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] - var_names = tuple( - k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1 - ) - sizes = tuple(v for v in sizes if v != 1) - return index, var_names, sizes # type: ignore[return-value] + @staticmethod + def drop_unused_symbols(index, var_names, sizes): + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: # Try to further simplify the indexes even if simplify_loops didn't # convert it to the simplest form because of the interference from # different indexing formulas. - free_symbols = index.free_symbols - var_ranges = { - k: V.graph.sizevars.simplify(v) - for k, v in self._var_ranges.items() - # TODO(jansel): explore this further normalization - # if k in free_symbols - } index_vars = [*var_ranges.keys()] - sizes = tuple(var_ranges.values()) + sizes = tuple(var_ranges.values()) # type: ignore[assignment] new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( index_vars, sizes, @@ -397,14 +463,28 @@ def canonicalize( new_vars = [*new_vars.keys()] new_sizes = [*new_sizes] - free_symbols = index.free_symbols - while new_vars and new_vars[-1] not in free_symbols: - # Reduction has last (reduced) dim in its sizes, but - # downstream users won't. Normalize this away. - new_vars.pop() - new_sizes.pop() + cls.drop_unused_symbols(index, new_vars, new_sizes) return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + def canonicalize( + self, index: sympy.Expr + ) -> Tuple[sympy.Expr, Tuple[sympy.Symbol, ...], Tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + def load(self, name: str, index: sympy.Expr) -> str: self._reads.add(MemoryDep(name, *self.canonicalize(index))) return f"load({name}, {sympy_str(index)})" @@ -436,25 +516,11 @@ def bucketize( return f"bucketize({values}, {offsets_name}, {sympy_str(offsets_size)}, {indexing_dtype}, {right})" -class _OpCounter: - """Shim to count how many times each op is used""" - - def __init__(self, inner) -> None: - super().__init__() - self.parent_handler = inner - self._op_counts: typing.Counter[Any] = collections.Counter() - - def __getattr__(self, name): - self._op_counts[name] += 1 - return getattr(self.parent_handler, name) - - class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: parent_handler = _RecordLoadStoreInner( var_ranges=var_ranges, normalize=normalize ) - parent_handler = _OpCounter(parent_handler) super().__init__(parent_handler=parent_handler) @@ -497,25 +563,57 @@ def extract_read_writes( *argsizes: Tuple[sympy.Expr, ...], normalize: bool = False, prefix: str = "d", + hidden_args=(), ): args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) - rw = RecordLoadStore(var_ranges, normalize=normalize) - with V.set_ops_handler(rw): - fn(*args) + + from .loop_body import LoopBody, MemoryUsageType + + if isinstance(fn, LoopBody): + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args([*args, *hidden_args]) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: sympy.Symbol(f"tmp{i}") for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + inner.bucketize( + None, entry.buffer_name, name_to_index[entry.index_name], None, None # type: ignore[arg-type] + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler if normalize: range_vars = [] # Number of vars could differ due to normalization else: - range_vars = list(itertools.chain.from_iterable(args)) + range_vars = [*itertools.chain.from_iterable(args)] - inner = rw.parent_handler.parent_handler return ReadWrites( OrderedSet(inner._reads), OrderedSet(inner._writes), inner._index_exprs, range_vars, var_ranges, - rw.parent_handler._op_counts, ) diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index c360ae9ec6d87f..bd1e0736d510fa 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -34,10 +34,7 @@ def mark_mixed_dtype_conv(conv): conv_user = next(iter(conv_user.users.keys())) - if not ( - conv_user.target == prims.convert_element_type.default - and conv_user.args[1] == conv_dtype - ): + if conv_user.target != prims.convert_element_type.default: return conv.meta["_allow_conv_mixed_dtype_folding"] = conv_dtype diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 55b2690edf9f02..5311c9789fa59e 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -707,15 +707,24 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): torch.baddbmm, args=(unsqueeze_biases, stack_inputs, transpose_weight), ) - bmm.meta["example_value"] = torch.baddbmm( - unsqueeze_biases.meta["example_value"], - stack_inputs.meta["example_value"], - transpose_weight.meta["example_value"], - ) - bmm_meta = bmm.meta["example_value"] + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) - bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) for i, linear in enumerate(batch_nodes): with graph.inserting_after(bmm): getitem = graph.call_function(operator.getitem, args=(bmm, i)) diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 416125a8f90ffe..d526f1bb6a64de 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -14,6 +14,7 @@ from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.multiprocessing.reductions import StorageWeakRef +from ...utils._ordered_set import OrderedSet from .. import config from ..pattern_matcher import ( CallFunction, @@ -475,6 +476,59 @@ def joint_graph_passes(graph: torch.fx.GraphModule): return graph +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + pass_dict=patterns, +) +def fix_iota_device(match: Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices: OrderedSet[torch.device] = OrderedSet() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes() + + @register_graph_pattern( CallFunction( torch.ops.prims.convert_element_type.default, @@ -498,7 +552,7 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp ) repl.meta.update(node.meta) node.replace_all_uses_with(repl) - match.erase_nodes(graph) + match.erase_nodes() @register_graph_pattern( @@ -507,12 +561,11 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp ) def pointless_view(match: Match, arg, size): """Remove no-op view""" - graph = match.graph node = match.output_node() arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] if size == arg_size: node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] - match.erase_nodes(graph) + match.erase_nodes() # When softmax is used with temperature or other scaling, we get the pattern diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 2ed93c2cd95d48..156760a68e7e69 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -835,7 +835,7 @@ def linear_bias_pattern(match, *args): ) repl.meta.update(add_node.meta) add_node.replace_all_uses_with(repl) - match.erase_nodes(graph) + match.erase_nodes() def _is_packable_mkldnn_rnn_layer(match): lstm_node = match.output_node() diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 8d2f0d5c7caad7..064c6de94aed61 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -22,6 +22,7 @@ from .. import config, ir, pattern_matcher from ..codegen.common import BackendFeature, has_backend_feature +from ..comms import remove_fsdp2_unsharded_param_graph_input_usage from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage from ..lowering import lowerings as L from ..pattern_matcher import ( @@ -76,6 +77,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): The IR here has been normalized and functionalized. """ + if not torch._dynamo.config.skip_fsdp_hooks: + remove_fsdp2_unsharded_param_graph_input_usage(gm.graph) + if config.dce: # has some issues with mutation in inference mode gm.graph.eliminate_dead_code() @@ -207,6 +211,9 @@ def register_lowering_pattern(pattern, extra_check=_return_true, pass_number=1): def is_valid_mm_plus_mm(match: Match): + if not torch._inductor.utils.use_max_autotune(): + return False + *b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape *b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape if k1 != k2: @@ -798,13 +805,20 @@ def remove_noop_ops(graph: torch.fx.Graph): def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized and triton_kernel_wrapper_functional + nodes into clones and the underlying mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ graph_pass = PatternMatcherPass() @register_graph_pattern( CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), pass_dict=graph_pass, ) - def replacement(match: Match, *args, **kwargs): + def _(match: Match, *args, **kwargs): from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense only_clone_these_tensors = tuple( @@ -822,12 +836,68 @@ def decomp(*flat_args): match.replace_by_example(decomp, flat_args, run_functional_passes=False) + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return auto_functionalized_v2_dense(*args, only_clone_these_bases, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + graph_pass.apply(graph) + for node in graph.find_nodes( op="call_function", target=torch.ops.higher_order.auto_functionalized ): raise AssertionError("auto_functionalized was not removed") + for node in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + + for node in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") + @register_lowering_pattern( CallFunction( diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index eef058ab0b5aa2..bca3361962b07c 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -100,6 +100,10 @@ def stack_to_unsqueeze_pass(graph): return None +def merge_concats_pass(graph): + return None + + @init_once_fakemode def lazy_init(): from . import efficient_conv_bn_eval, split_cat # noqa: F401 # noqa: F401 @@ -180,6 +184,12 @@ def shape_prop(mod) -> None: example_inputs, "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass", ) + pass_execution_and_save( + merge_concats_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)] Apply merge_concats_pass", + ) pass_execution_and_save( fuse_split_linear_add_pass.apply, gm, diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 6e17e2ff456fc6..3c918d480704e9 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1561,6 +1561,27 @@ def _register_woq_mm_int8_pattern3(): _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) +def _register_woq_mm_int8_pattern4(): + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg("weight"), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + def _register_quantization_lowerings(): _register_quantization_unary_fusion() _register_quantization_binary_fusion() @@ -1573,6 +1594,7 @@ def _register_woq_lowerings(): _register_woq_mm_int8_pattern1() _register_woq_mm_int8_pattern2() _register_woq_mm_int8_pattern3() + _register_woq_mm_int8_pattern4() def _is_valid_dequant_promotion_pattern(dtype=torch.float32): diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index a7cd01dc0043e3..59706134f85fea 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -11,7 +11,7 @@ kernel_side_table, triton_kernel_wrapper_functional, ) -from torch._inductor import inductor_prims +from torch._inductor import config, inductor_prims from torch._inductor.fx_utils import get_node_storage, is_node_realized from torch._inductor.lowering import ( inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, @@ -467,6 +467,7 @@ def can_inplace(node, mutated_arg): if get_node_storage(mutated_arg) is None: return False shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + if mutated_arg.op in ("placeholder", "get_attr"): # Get the first copy_ node that mutates the mutated_arg. copy_node = copy_nodes.get(mutated_arg, None) @@ -482,6 +483,9 @@ def can_inplace(node, mutated_arg): return True elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + # If mutated arg is view of any of the inputs of the graph, # do not allow for inplacing. # This would require more sophisticated algorithm to handle @@ -491,9 +495,30 @@ def can_inplace(node, mutated_arg): node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg ) + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ): + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + possibly_missed_reinplacing_opportunities, + ) + torch._dynamo.utils.counters["inductor"][ + "possibly_missed_reinplacing_opportunities" + ] += len(possibly_missed_reinplacing_opportunities) + replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {} - def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs, node_name): + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False + ): tensors_to_clone: List[str] = [] storage_of_reinplaced_args = set() possibly_missed_reinplacing_opportunities = [] @@ -507,6 +532,7 @@ def tensor_with_same_storage_already_reinplaced(arg): for arg in old_tensors_to_clone: assert arg in kwargs + mutated_arg = kwargs[arg] # Let's say we have: @@ -523,12 +549,18 @@ def tensor_with_same_storage_already_reinplaced(arg): mutated_arg ) if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) if copy_node is not None: replace_dict[copy_node] = copy_node.args[0] - for user in node.users: - if user.target == operator.getitem and user.args[1] == arg: - replace_dict[user] = mutated_arg + if not auto_functionalize_v2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output atindex size(out)+i. + # This used to compare string with integers before for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg if isinstance(mutated_arg, (list, tuple)): for a in mutated_arg: @@ -540,18 +572,12 @@ def tensor_with_same_storage_already_reinplaced(arg): possibly_missed_reinplacing_opportunities.append(arg) tensors_to_clone.append(arg) - log.info( - "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " - "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " - "memory usage and performance.", + log_inplace_results( node_name, old_tensors_to_clone, tensors_to_clone, possibly_missed_reinplacing_opportunities, ) - torch._dynamo.utils.counters["inductor"][ - "possibly_missed_reinplacing_opportunities" - ] += len(possibly_missed_reinplacing_opportunities) return tensors_to_clone for node in graph.nodes: @@ -565,17 +591,37 @@ def tensor_with_same_storage_already_reinplaced(arg): if copy_node is not None: replace_dict[copy_node] = copy_node.args[0] node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: List[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + auto_functionalize_v2=True, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone elif node.target == torch.ops.higher_order.auto_functionalized: _mutable_op = node.args[0] - from torch._higher_order_ops.auto_functionalize import get_mutable_arg_names + from torch._higher_order_ops.auto_functionalize import get_mutable_args - tensors_to_clone = get_mutable_arg_names(_mutable_op) + tensors_to_clone, _ = get_mutable_args(_mutable_op) # Don't try to reinplace Optional[Tensor] args that are None. tensors_to_clone = [ t for t in tensors_to_clone if node.kwargs[t] is not None ] tensors_to_clone = reinplace_and_refine_tensors_to_clone( - tensors_to_clone, node.kwargs, _mutable_op._name + tensors_to_clone, + node.kwargs, + _mutable_op._name, + auto_functionalize_v2=False, ) # Stash the metadata. There is a pass later on where we decompose @@ -591,7 +637,14 @@ def tensor_with_same_storage_already_reinplaced(arg): if isinstance(kernel, JITFunction): kernel_name = kernel.fn.__name__ elif isinstance(kernel, Autotuner): - kernel_name = kernel.base_fn.__name__ + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ else: raise AssertionError("Unknown triton kernel type") diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index d0098a535ee295..f850ecf6008c94 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -238,7 +238,9 @@ def remove_split_with_size_one(match: Match, *args, **kwargs): # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. return # remove the dummy split whose split sections size is one - if len(split_sections) == 1: + # theoretically nodes with no users should be removed, but we have seen the corner case + # thus we add its uers check to walk around the StopIteration error. + if len(split_sections) == 1 and len(split_node.users.keys()) > 0: # find the grand children of the split_node next_users = find_next_users(split_node) user = next(iter(split_node.users.keys())) @@ -447,6 +449,12 @@ def normalize_reshape_default(match: Match, *args, **kwargs): return reshape_input = get_arg_value(reshape_node, 0) + from torch.fx.experimental.symbolic_shapes import free_symbols + + if free_symbols(reshape_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", reshape_node) + return + with match.graph.inserting_after(reshape_node): new_reshape_node = match.graph.call_function( torch.reshape, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 514c5b2f9b2b97..6b44723a342def 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -124,7 +124,7 @@ def log_module_code(*args: Any, **kwargs: Any) -> None: pass -def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: +def supported_dtype_of_cpp_wrapper(dtype: torch.device, device_type: str) -> bool: supported_dtype = { torch.float32, torch.float64, @@ -140,7 +140,7 @@ def supported_dtype_of_cpp_wrapper(dtype: torch.device, cuda: bool) -> bool: torch.complex128, torch.float16, } - if cuda: + if device_type == "cuda": supported_dtype.add(torch.float8_e4m3fn) supported_dtype.add(torch.float8_e5m2) supported_dtype.add(torch.float8_e4m3fnuz) @@ -235,9 +235,7 @@ def _get_overload_packet( op = _get_overload_packet(cur) if not op: continue - if op in ops_dislike_padding or ( - user_visible_outputs and cur.name in user_visible_outputs - ): + if op in ops_dislike_padding: cur.meta["dislike_padding"] = True if cur.meta.get("dislike_padding", False): @@ -248,6 +246,13 @@ def _get_overload_packet( continue if prior_op not in ops_like_padding: prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. + if ( + not config.pad_outputs + and user_visible_outputs + and cur.name in user_visible_outputs + ): + cur.meta["dislike_padding"] = True class GraphLowering(torch.fx.Interpreter): @@ -360,7 +365,7 @@ def __init__( self.device_idxs: OrderedSet[int] = ( const_module.device_idxs if const_module else OrderedSet() ) - self.cuda = False + self.device_type = "cpu" self.buffers: List[ir.Buffer] = [] self.operations: List[ir.Operation] = [] self.const_output_index: Dict[str, int] = ( @@ -787,8 +792,8 @@ def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: name = self.qualify_name(f"buf{len(self.buffers)}") self.buffers.append(buffer) self.name_to_buffer[name] = buffer - # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 if ( + # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144 not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()) and buffer.get_device() is not None ): @@ -849,7 +854,6 @@ def get_original_value_of_constant(self, name: str) -> torch.Tensor: def allocate_non_dup_const_name( self, name: Optional[str], data: Union[Tensor] ) -> str: - orig_name = name if not config.aot_inductor.use_runtime_constant_folding: for constant_name, value in self.constants.items(): if ( @@ -866,7 +870,7 @@ def allocate_non_dup_const_name( if name is None: name = f"constant{len(self.constants)}" - assert name is not None + orig_name = name if name[0].isdigit(): name = f"constant_{name}" name = self.qualify_name(name) @@ -953,7 +957,8 @@ def placeholder( ) self.graph_inputs[target] = tensor self.graph_inputs_original[target] = tensor.data.data - self.add_device_info(example.device) + if self.current_node.users: # cudagraphs should work with an unused CPU input + self.add_device_info(example.device) # Note: [Input Alignment handling in Inductor] # Alignment matters for generating efficient code. Some operations, @@ -1232,10 +1237,30 @@ def propagate_mutation( If fx_node mutates any of new_args/new_kwargs, and they are different from old_args/old_kwargs, then we need to update the original tensor. """ - assert isinstance(fx_node.target, torch._ops.OpOverload) assert len(old_args) == len(new_args) assert len(old_kwargs) == len(new_kwargs) + if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation: + kwargs = fx_node.kwargs["kwargs"] + assert isinstance(kwargs, dict) + mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + old_kwargs["kernel_idx"], + old_kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + ) + for name in mutated: + old_arg = old_kwargs["kwargs"][name] + new_arg = new_kwargs["kwargs"][name] + if old_arg is new_args: + continue + self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + return + + assert isinstance(fx_node.target, torch._ops.OpOverload) + def maybe_propagate( schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode ) -> None: @@ -1297,6 +1322,25 @@ def debug(msg: str) -> None: # if they do, and if the target is mutable, then we need to # write the new values back into the original inputs. self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + elif ( + n.op == "call_function" + and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + and config.triton_kernel_default_layout_constraint != "flexible_layout" + ): + debug("user_defined_triton_kernel_layout_constraints") + if ( + config.triton_kernel_default_layout_constraint + == "needs_fixed_stride_order" + ): + old_args = args # type: ignore[possibly-undefined] + old_kwargs = kwargs # type: ignore[possibly-undefined] + args, kwargs = torch._inductor.lowering.constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + else: + raise RuntimeError( + f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}" + ) elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride @@ -1358,9 +1402,8 @@ def debug(msg: str) -> None: strides = n.meta["val"].stride() if len(strides): allow_padding = ( - n.name not in self.user_visible_outputs - and not is_input_for_as_strided - ) + config.pad_outputs or n.name not in self.user_visible_outputs + ) and not is_input_for_as_strided dense = torch._prims_common.is_non_overlapping_and_dense( n.meta["val"] ) @@ -1632,14 +1675,10 @@ def validate_can_generate_cpp_wrapper(self) -> None: ): dtype = may_get_constant_buffer_dtype(value) - if not supported_dtype_of_cpp_wrapper(dtype, self.cuda): + if not supported_dtype_of_cpp_wrapper(dtype, self.device_type): raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}") def init_wrapper_code(self) -> None: - self.cuda = "cuda" in self.device_types - if self.cpp_wrapper: - self.validate_can_generate_cpp_wrapper() - device_types = self.device_types.copy() device_types.discard("cpu") device_types.discard("meta") @@ -1648,13 +1687,18 @@ def init_wrapper_code(self) -> None: "+".join(device_types) ) only_cpu = len(device_types) == 0 - device_type = "cpu" if only_cpu else device_types.pop() + self.device_type = "cpu" if only_cpu else device_types.pop() + + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() - self.device_ops = get_device_op_overrides(device_type) + self.device_ops = get_device_op_overrides(self.device_type) wrapper_code_gen_cls = get_wrapper_codegen_for_device( - device_type, self.cpp_wrapper + self.device_type, self.cpp_wrapper ) - assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported" + assert ( + wrapper_code_gen_cls is not None + ), f"Device {self.device_type} not supported" self.wrapper_code = wrapper_code_gen_cls() if self.const_module: @@ -1672,14 +1716,10 @@ def codegen_with_cpp_wrapper(self) -> Tuple[str, List[Tuple[int, Node]]]: wrapper code and run it to generate autotuned kernel binaries in the first pass; and then generate cpp wrapper code and compile it to a dynamic library in the second pass. """ - if "cuda" in self.device_types: + if any(device in self.device_types for device in ["cuda", "xpu"]): # first pass self.cpp_wrapper = False - # Although triton.store_cubin was OrderedSet in compile_fx, the backward pass didn't pick - # that up. In theory it should work by only setting triton.store_cubin to True here, - # but that will cause a problem when use_runtime_constant_folding is OrderedSet. - with config.patch({"triton.store_cubin": True}): - compiled = self.compile_to_module().call + compiled = self.compile_to_module().call if not config.triton.autotune_at_compile_time: @@ -1902,7 +1942,7 @@ def compile_to_fn(self) -> Any: # Directly return the file path with the compiled code return AotCodeCompiler.compile( - self, code, serialized_extern_kernel_nodes, cuda=self.cuda + self, code, serialized_extern_kernel_nodes, device_type=self.device_type ) else: return self.compile_to_module().call diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 82a23d3e60cf10..caba77371aac00 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -68,9 +68,16 @@ def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: lambda seeds, index: seeds[index], doc="Extract a single seed from the result of inductor_seeds()", ) +# inductor_random() doesn't accept a dtype. +# instead, its lowering always burns in float32, and conversions to a different type +# are explicit in the graph. We therefore need this impl (used during tracing) to hardcoded +# the dtype, so it always faithfully produces a float32 tensor during tracing, +# even if the default dtype is set to something else. random = make_prim( "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", - lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device), + lambda size, seed, mode: getattr(torch, mode)( + size, device=seed.device, dtype=torch.float32 + ), doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", ) randint = make_prim( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8f5cb5d3b06290..0929996fe17452 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6,7 +6,6 @@ import functools import itertools import logging -import re import textwrap import traceback from contextlib import nullcontext @@ -72,7 +71,8 @@ extract_read_writes, var_builder, ) -from .ops_handler import OpCounterCSE +from .loop_body import LoopBody +from .ops_handler import OpCounterCSE, OpCountResult from .runtime.benchmarking import benchmarker from .runtime.hints import ReductionHint from .utils import ( @@ -412,6 +412,7 @@ def codegen_reference(self, writer=None): dtype: torch.dtype get_name: Callable[[], str] get_reads: Callable[[], Any] + num_reads: Callable[[], int] get_stride: Callable[[], Any] get_storage_numel: Callable[[], Any] has_exceeded_max_reads: Callable[[], bool] @@ -555,25 +556,25 @@ def _index(ranges, prefix=SymT.INDEX): ] @cache_on_self - def inner_fn_opcount(self): + def inner_fn_opcount(self) -> OpCountResult: opcounter = OpCounterCSE(V.MockHandler()) - with V.set_ops_handler(opcounter), patch.object( FlexibleLayout, "allow_indexing", True ): self.inner_fn(*self.inner_fn_args()) - return opcounter.op_count + return opcounter.getvalue() def inner_fn_args(self): return (self._index(self.ranges),) + @cache_on_self def inner_fn_str(self): return V.KernelFormatterHandler.ir_to_string( self.inner_fn, *self.inner_fn_args() ) def has_large_inner_fn(self): - return self.inner_fn_opcount() > config.realize_opcount_threshold + return self.inner_fn_opcount().num_ops > config.realize_opcount_threshold def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -594,7 +595,10 @@ def get_reads(self): ).reads def get_read_names(self) -> OrderedSet[str]: - return OrderedSet(dep.name for dep in self.get_reads()) + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def num_reads(self): + return len(self.inner_fn_opcount().read_buffers) def get_reduction_size(self): raise NotImplementedError( @@ -2682,6 +2686,9 @@ def codegen_reference(self, writer=None): dtype=self.layout.dtype, ) + def num_reads(self): + return 1 + @dataclasses.dataclass class DtypeView(BaseView): @@ -2872,12 +2879,7 @@ def is_contiguous_strides_for_shape( def get_align_for_dtype(dtype: torch.dtype) -> int: - """ - CUDA max memory transaction size is 128 bytes for a warp. - We pick `128 // dtype.itemsize` as alighment so GPU can do coalesced - memory access. - """ - return 128 // dtype.itemsize + return config.padding_alignment_bytes // dtype.itemsize @dataclasses.dataclass @@ -3016,29 +3018,12 @@ def _pad_strides(in_strides, size, dtype): # smallest stride to be 1. new_strides[fill_order[0]] = 1 - # Don't align a too small stride since that causes too much memory increase. - # Pad too small stride may also cause perf loss. We may result in many tiny data blocks - # with gaps in between. That causes less coalesced GPU memory access! - # - # Initially we pick 320 as the threshold since for alignement=16, - # that results in at most 5% memory cost. - # - # But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. - # Let's say an inner reduction has a row size 513. Inductor will generate - # persistent reduction code. - # If we do padding, the strides are not contiguous any more. Inductor - # uses a much smaller threshold for persistent reduction in this case and - # generates potentially worse non-persistent reduction code. - # - # This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. - # (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) - align_stride_threshold = 1024 padded = False for rank, idx in enumerate(fill_order[1:], start=1): prev_idx = fill_order[rank - 1] stride = new_strides[prev_idx] * size[prev_idx] - if stride > align_stride_threshold and stride % align != 0: + if stride > config.padding_stride_threshold and stride % align != 0: stride = ceildiv(stride, align) * align padded = True new_strides[idx] = stride @@ -3531,7 +3516,8 @@ def __post_init__(self): class InputBuffer(Buffer): - pass + def num_reads(self): + return 1 class ConstantBuffer(InputBuffer): @@ -3592,9 +3578,11 @@ def get_computed_buffer_name(self): return self.data.name return None - @cache_on_self def num_reads(self): - return len(self.get_read_writes().reads) + return self.data.num_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() def get_read_writes(self): with patch.object(FlexibleLayout, "allow_indexing", True): @@ -3667,12 +3655,6 @@ def get_fill_order(self): self.data.get_pointwise_size(), self.data.get_reduction_size() ) reads = self.get_read_writes().reads - reads_bufs = [ - V.graph.name_to_buffer[r.name] - if r.name in V.graph.name_to_buffer.keys() - else None - for r in reads - ] # only consider reads to buffer of same size # ignore StarDeps because they don't contribute stride information assert all( @@ -3719,6 +3701,7 @@ def get_default_sizes_body(self): self.get_store_function(), (args if self.get_reduction_type() else args[:1]), var_ranges, + *args, ) index_vars = [] reduce_vars: List[Any] = [] @@ -3738,6 +3721,7 @@ def get_default_sizes_body(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): """ This is a main place where we do loop transformations in a @@ -3753,6 +3737,8 @@ def simplify_and_reorder( to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) on CPU by preventing indexing simplifications and obtaining index/reduce ranges for the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. """ ( (index_size, reduce_size), @@ -3760,6 +3746,15 @@ def simplify_and_reorder( (index_vars, reduce_vars), ) = self.get_default_sizes_body() + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + index_formulas = [*body.indexing_exprs.values()] if extra_indexing_constraints is not None: assert ( @@ -3782,40 +3777,59 @@ def simplify_and_reorder( ] index_formulas += extra_indexing_expr - memory_addrs = [*body.writes_name2expr.values()] + memory_addrs = [*body.get_write_exprs()] if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): - memory_addrs.extend(body.reads_name2expr.values()) + memory_addrs.extend(body.get_read_exprs()) - def simplify_and_reorder(x_vars, support_vars, sizes): + def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): sizes, reindex0, reindex1 = self._apply_loop_reordering( x_vars, support_vars, sizes, memory_addrs ) # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] x_vars = reindex0(x_vars) - sizes, reindex2, prune = V.graph.sizevars._simplify_loops( - x_vars, - sizes, - index_prevent_reordering(index_formulas, x_vars, sizes), - ) - reindex = fuse_reindexing(reindex1, reindex2) + + if simplify_loops: + sizes, reindex2, prune = V.graph.sizevars._simplify_loops( + x_vars, + sizes, + index_prevent_reordering(index_formulas, x_vars, sizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 return sizes, reindex, reindex1 support_vars = index_vars + reduce_vars + should_merge_loops = ( + self.get_device().type != "cuda" or not config.loop_ordering_after_fusion + ) iter_ranges, iter_reindex, _ = simplify_and_reorder( index_vars, support_vars, index_size, + should_merge_loops, ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. reduce_ranges, reduce_reindex, _ = simplify_and_reorder( - reduce_vars, support_vars, reduce_size + reduce_vars, support_vars, reduce_size, should_merge_loops ) # retrace the loop body with simplification and reordering applied (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( - iter_ranges, reduce_ranges, prefix="z" + iter_ranges, + reduce_ranges, + prefix="z", ) body = LoopBody( - body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, ) return (iter_ranges, reduce_ranges), body @@ -3886,9 +3900,9 @@ def __init__(self, layout, inputs, make_kernel_render): V.graph.register_operation(self) def get_read_writes(self): - return self.normalized_read_writes() + return self.extract_read_writes(normalize=True) - def normalized_read_writes(self): + def extract_read_writes(self, normalize): name = self.get_name() indexer = self.layout.make_indexer() @@ -3897,7 +3911,7 @@ def dummy(index, rindex): return ops.store(name, indexer(index), "fake") deps = dependencies.extract_read_writes( - dummy, self.get_size(), (), normalize=True + dummy, self.get_size(), (), normalize=normalize ) deps.reads = OrderedSet(dependencies.StarDep(x.get_name()) for x in self.inputs) return deps @@ -3917,6 +3931,7 @@ def should_allocate(self): def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ): return ( ( @@ -4149,6 +4164,9 @@ def unwrap_storage(inputs): def is_extern(self): return True + def num_reads(self): + return 1 + class NopKernel(InputsKernel): def is_no_op(self): @@ -5203,10 +5221,19 @@ def codegen(self, wrapper): raw_args = [ self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel ] + + # NOTE: raw_args doesn't include autotuned args. + # But, kernel.constexprs includes indices of autotuned args. + # So, let's recalculate constexpr indices wrt to raw_args. + constexpr_indices = [] + for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel): + if kernel.arg_names.index(kwarg) in kernel.constexprs: + constexpr_indices.append(idx) + # Call to kernel self.codegen_comment(wrapper) wrapper.generate_user_defined_triton_kernel( - new_name, raw_args, self.grid, configs, triton_meta, kernel.constexprs + new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices ) def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: @@ -5844,7 +5871,11 @@ def go(expr, keypath): elif isinstance(keypath[0], CallMethodKey): return go(f"{expr}.{keypath[0].name}()", keypath[1:]) elif isinstance(keypath[0], pytree.SequenceKey): - return go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + return ( + go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:]) + if V.graph.cpp_wrapper + else go(f"{expr}[{keypath[0].idx}]", keypath[1:]) + ) elif isinstance(keypath[0], DivideByKey): # TODO: need to assert divisibility # TODO: this is invalid C++ codegen @@ -6014,17 +6045,9 @@ def codegen(self, wrapper): if V.graph.cpp_wrapper: from torchgen.aoti.fallback_ops import inductor_fallback_ops - if ( - config.is_fbcode() - and kernel not in has_c_shim + if config.abi_compatible and str(kernel) not in inductor_fallback_ops: # C shim v2 is torchgen-ed, which should cover all aten ops. - # If you do hit a missed op, please update gen_aoti_c_shim.py. - and config.c_shim_version == "1" - ) or ( - config.abi_compatible - and config.c_shim_version == "2" - and str(kernel) not in inductor_fallback_ops - ): + # If you do hit a missed op, please update fallback_ops.py. log.warning( "%s is missing a c-shim implementation, using proxy executor as fallback", kernel, @@ -6361,8 +6384,7 @@ def realize_hint(self): """ if ( isinstance(self.data, (Pointwise, Reduction)) - and self.num_reads() > 1 - and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() + and self.data.inner_fn_opcount().nontrivial_read_count > 1 ): self.realize() @@ -6372,63 +6394,30 @@ def has_exceeded_max_reads(self): or self.has_large_inner_fn() ) - def mark_reuse(self, users): + def should_realize_on_reuse(self, users): """ A heuristic to decide if we should realize a tensor that is used multiple times. """ - - def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): - """ - The heuristic for realizing reused result of heavy ops on cpu - """ - heavy_ops = ["exp", "sigmoid"] # a list of heavy ops - fn_str = loops.inner_fn_str() - return any((op + "(") in fn_str for op in heavy_ops) - - if ( - users > 1 - and isinstance(self.data, (Pointwise, Reduction)) - and ( + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( self.num_reads() > config.realize_reads_threshold or self.has_large_inner_fn() - or (is_cpu(self.data) and should_realize_on_cpu(self.data)) ) - ): + return False + + def mark_reuse(self, users): + if self.should_realize_on_reuse(users): self.realize() - @cache_on_self def num_reads(self): - data = self.data - if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): - return 1 - if isinstance(data, ComputedBuffer): - read_writes = data.get_read_writes() - else: - assert isinstance(data, (Pointwise, Reduction)), type(data) - read_writes = ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=data.get_device(), - dtype=data.get_dtype(), - size=data.get_size(), - ), - data=data, - ).get_read_writes() - return len(read_writes.reads) - - @cache_on_self - def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): - # Skip the check for non Pointwise instances - return ( - (sum(read.index != 0 for read in self.data.get_reads()) > 1) - if isinstance(self.data, Pointwise) - and all( - not isinstance(read, dependencies.StarDep) - for read in self.data.get_reads() - ) - else True - ) + return self.data.num_reads() @dataclasses.dataclass @@ -6508,7 +6497,7 @@ def create( subgraph.graph.run(*fake_operands) true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] - false_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] + false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr] for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): if _has_aliased_buffers(true_outputs): @@ -6754,313 +6743,6 @@ def codegen_reference(self, writer=None): return self.name -class InterpreterShim(torch.fx.Interpreter): - @staticmethod - @functools.lru_cache(None) - def _dummy_gm(): - return torch.fx.symbolic_trace(identity) - - def __init__(self, graph, submodules): - # call super() with a placeholder to avoid constructing a - # GraphModule which is very expensive (it does codegen). - super().__init__(self._dummy_gm(), garbage_collect_values=False) - self.module = self # type: ignore[assignment] - self.graph = graph - self.submodules = submodules - self.extra_traceback = False - self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] - self.current_node = None - - def run_node(self, n: torch.fx.Node) -> Any: - self.current_node = n - return super().run_node(n) - - def run(self, *args, **kwargs): - with V.set_interpreter_handler(self): - return super().run(*args, **kwargs) - - -class LoopBody: - """ - Captures the body of a Loops subclass into an FX graph. Persists any - indexing simplifications and makes it easier to analyze loop bodies. - """ - - def __init__(self, fn, args, var_ranges): - super().__init__() - self.var_ranges = var_ranges - self.indexing_exprs = {} - self.indexing_exprs_name = {} - self.reads = [] - self.writes = [] - self.reads_name2expr = {} - self.writes_name2expr = {} - self.other = [] - self.submodules = {"get_index": self.get_index} - self.subblocks = {} - self.indirect_vars = [] - self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} - self.root_block = LoopBodyBlock(self, fn, args) - self.indexing = None - - @cache_on_self - def get_nodes(self): - all_graphs = itertools.chain( - (self.root_block.graph,), - (block.graph for block in self.subblocks.values()), - ) - return [node for graph in all_graphs for node in graph.nodes] - - @cache_on_self - def bounds(self): - # Doing a local import to avoid dumping all the code here - from .bounds import BoundVars - - return BoundVars(self) - - def debug_str(self): - lines = [f"var_ranges = {dict(self.var_ranges)}"] - lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) - lines.extend( - [ - block.debug_str(name) - for name, block in itertools.chain( - [("body", self.root_block)], self.subblocks.items() - ) - ] - ) - return "\n".join(lines) - - def add_index_expr(self, expr: sympy.Expr, category, buf_name): - getattr(self, category).append(expr) - if buf_name is not None: - getattr(self, f"{category}_name2expr")[buf_name] = expr - if expr not in self.indexing_exprs_name: - name = f"index{len(self.indexing_exprs)}" - self.indexing_exprs_name[expr] = name - self.indexing_exprs[name] = expr - return self.indexing_exprs_name[expr] - - def add_submodule(self, block, prefix): - """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" - if prefix[-1].isnumeric() and prefix not in self.submodules: - name = prefix - else: - name = f"{prefix}{len(self.submodules)}" - self.submodules[name] = block - return name - - def add_indirect(self, size): - var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) - assert var not in self.indirect_var_ranges - self.indirect_vars.append(var) - self.indirect_var_ranges[var] = size - return var - - def replace_indirect(self, old, new): - """Swap in a variable used in indirect indexing""" - if str(old) == str(new): - return - assert self.indexing is not None - self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} - - def get_index(self, name): - assert self.indexing is not None - return self.indexing[name] - - def indexing_from_args(self, indices): - index = [*itertools.chain.from_iterable(indices)] - assert len(index) == len(self.var_ranges), (index, self.var_ranges) - assert all(v not in self.var_ranges for v in index) - replacements = dict(zip(self.var_ranges.keys(), index)) - return { - name: sympy_subs(expr, replacements) - for name, expr in self.indexing_exprs.items() - } - - def __call__(self, *indices): - self.indexing = self.indexing_from_args(indices) - result = self.root_block() - self.indexing = None - return result - - -class LoopBodyBlock: - """ - Captures the body of a Loops subclass into an FX graph. - In normal cases there will be a 1:1 mapping between LoopBody and - LoopBodyBlock, hower in the case of ops.masked() the masked out - operations will manifest as an extra LoopBodyBlock. - """ - - def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): - self.body = body - - def add_index(expr, category, buf_name=None): - return tracer.create_proxy( - "call_module", - "get_index", - (self.body.add_index_expr(expr, category, buf_name),), - {}, - ) - - class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - self.name = "CaptureIndexing" - - def load(self, name: str, index: sympy.Expr): - index = add_index(index, "reads", name) - return self._inner.load(name, index) - - def store(self, name, index, value, mode=None): - index = add_index(index, "writes", name) - return self._inner.store(name, index, value, mode) - - def store_reduction(self, name, index, value): - index = add_index(index, "writes", name) - return self._inner.store_reduction(name, index, value) - - def reduction(self, dtype, src_dtype, reduction_type, value): - result = self._inner.reduction(dtype, src_dtype, reduction_type, value) - if "welford" in reduction_type: - return tuple(result[i] for i in range(3)) - return result - - def index_expr(self, index, dtype): - if isinstance(index, (int, sympy.Integer)): - return self._inner.constant(int(index), dtype) - index = add_index(index, "other") - return self._inner.index_expr(index, dtype) - - def check_bounds(self, index, size, lower, upper): - index = add_index(index, "other") - size = add_index(size, "other") - return self._inner.check_bounds(index, size, lower, upper) - - def bucketize( - self, - values, - offsets_name: str, - offsets_size: sympy.Expr, - indexing_dtype: torch.dtype, - right: bool, - ): - offsets_size = add_index(offsets_size, "other") - return self._inner.bucketize( - values, offsets_name, offsets_size, indexing_dtype, right - ) - - @staticmethod - def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): - """ - Recursively capture the masked out body in another LoopBodyBlock - """ - - subblock: LoopBodyBlock - - def shim(mask, other): - return V.ops.masked(mask, subblock, other) - - name = self.body.add_submodule(shim, "masked_subblock") - subblock = LoopBodyBlock(self.body, masked_body, []) - self.body.subblocks[name] = subblock - return tracer.create_proxy( - "call_module", name, (mask_proxy, other_proxy), {} - ) - - @staticmethod - def scan( - dtype_proxy, - combine_fn: Callable[ - [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] - ], - value_proxy, - ): - def shim(dtypes, values): - return V.ops.scan(dtypes, combine_fn, values) - - name = self.body.add_submodule(shim, "scan") - result = tracer.create_proxy( - "call_module", - name, - (dtype_proxy, value_proxy), - {}, - ) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(value_proxy))) - - def sort(self, dtypes, values, stable, descending): - result = self._inner.sort(dtypes, values, stable, descending) - # Proxies are iterable, but some methods expect tuples/lists - return tuple(result[i] for i in range(len(values))) - - def frexp(self, value_proxy): - result = self._inner.frexp(value_proxy) - # Proxies are iterable, but some methods expect tuples/lists - return (result[0], result[1]) - - @staticmethod - def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): - """ - Flow data from tensors into indexing formulas. - Introduce a call_module to update the indexing. - """ - - var = self.body.add_indirect(size) - - def set_indirect(new_var): - self.body.replace_indirect( - var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) - ) - - tracer.create_proxy( - "call_module", - self.body.add_submodule(set_indirect, f"set_{var}"), - (index_proxy,), - {}, - ) - return var - - @staticmethod - def output(result): - tracer.create_proxy("output", "output", (result,), {}) - - tracer = torch.fx.Tracer() - tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) - proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) - - from .index_propagation import IndexPropagation - from .sizevars import SimplifyIndexing - - handler: Any = SimplifyIndexing( - CaptureIndexing(proxy_ops), self.body.var_ranges - ) - if config.constant_and_index_propagation: - handler = IndexPropagation( - handler, self.body.var_ranges, self.body.indirect_var_ranges - ) - - with V.set_ops_handler(handler): - # This indirection is just a cute way to get IndexPropagation to - # unwrap the return value. - ops.output(fn(*args)) - self.graph = tracer.graph - - def __call__(self): - graph = self.graph - submodules = self.body.submodules - - return InterpreterShim(graph, submodules).run(V.get_ops_handler()) - - def debug_str(self, name="block"): - code = torch.fx.GraphModule(self.body.submodules, self.graph).code - return re.sub( - # strip `; del var0` suffixes to make output prettier - r";[^\n]*", - "", - code.strip().replace("def forward(", f"def {name}("), - ) - - class _CollectiveKernel(FallbackKernel): def should_allocate(self): return False diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index de7ff7ee913610..d6dfca28662b3f 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -3,7 +3,7 @@ import logging import math -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Sequence, Tuple import sympy @@ -17,9 +17,11 @@ ExternKernel, FixedLayout, FlexibleLayout, + get_stride_order, InputBuffer, IRNode, StorageBox, + stride_order2fill_order, Subgraph, TensorBox, ) @@ -29,6 +31,29 @@ log = logging.getLogger(__name__) aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len( + fill_order + ), "Length of sizes must match the length of the fill order" + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta): @@ -55,6 +80,13 @@ def maybe_realize(args: List[Optional[IRNode]]): return tree_map(lambda x: realize_inputs(x) if x is not None else None, args) +def get_float32_precision(): + if torch.get_float32_matmul_precision() == "highest" or torch.version.hip: + return "'ieee'" + else: + return "'tf32'" + + def build_subgraph_buffer( args: List[TensorBox], subgraph: Subgraph, @@ -172,22 +204,27 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - Z = {{size("Q", 0)}} + ZQ = {{size("Q", 0)}} HQ = {{size("Q", 1)}} Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} KV_LEN = {{size("K", 2)}} MATMUL_PRECISION = Q.dtype.element_ty q_start = tl.program_id(0) - off_z = tl.program_id(1) // HQ + off_zq = tl.program_id(1) // HQ off_hq = tl.program_id(1) % HQ + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV off_hkv = off_hq // GQA_SHARED_HEADS off_g = off_hq % GQA_SHARED_HEADS - q_offset = off_z * stride_qz + off_hq * stride_qh - k_offset = off_z * stride_kz + off_hkv * stride_kh - v_offset = off_z * stride_vz + off_hkv * stride_vh + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh Q = Q + q_offset K = K + k_offset @@ -196,14 +233,15 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - sparse_idx_z = off_z % SPARSE_Z + sparse_idx_z = off_zq % SPARSE_Z sparse_idx_hq = off_hq % SPARSE_HQ SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - SPARSE_Q_BLOCK_CNT: tl.constexpr = tl.cdiv(Q_LEN, SPARSE_Q_BLOCK_SIZE) - SPARSE_KV_BLOCK_CNT: tl.constexpr = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -214,8 +252,8 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK # KV_IDX and KV_NUM_BLKS are always contiguous. sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq - sparse_kv_num_blks_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT + q_start // SPARSE_Q_MULTIPLE - sparse_kv_idx_offset = sparse_hz_offset * SPARSE_Q_BLOCK_CNT * SPARSE_KV_BLOCK_CNT + (q_start // SPARSE_Q_MULTIPLE) * SPARSE_KV_BLOCK_CNT # noqa: B950 + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 Q_block_ptr = tl.make_block_ptr( base=Q, @@ -231,7 +269,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK q = tl.load(Q_block_ptr) else: # boundary check is not free, so we only do it when necessary. - q = tl.load(Q_block_ptr, boundary_check=(0,)) + q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option = "zero") # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We don't know anything "special" about these blocks, so we need to apply @@ -239,7 +277,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK kv_indices = KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K, @@ -263,7 +301,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK {{gen_argdefs()}}, q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, acc, l_i, m_i, - off_z, off_hq, offs_m[:, None], offs_n[None, :], + off_zq, off_hq, offs_m[:, None], offs_n[None, :], kv_indices, kv_num_blocks, 0, block_n_end, MATMUL_PRECISION, @@ -278,7 +316,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK kv_indices = FULL_KV_IDX + sparse_kv_idx_offset kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K, @@ -302,7 +340,7 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK {{gen_argdefs()}}, q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN, acc, l_i, m_i, - off_z, off_hq, offs_m[:, None], offs_n[None, :], + off_zq, off_hq, offs_m[:, None], offs_n[None, :], kv_indices, kv_num_blocks, 0, block_n_end, MATMUL_PRECISION, @@ -314,18 +352,16 @@ def get_offset_for_next_block(loop_iter, col_indices, total_blocks, SPARSE_BLOCK # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step l_i = tl.where(l_i == 0.0, 1, l_i) - masked_out_rows = (m_i == float("-inf")) - m_i = tl.where(masked_out_rows, 0, m_i) acc = acc / l_i[:, None] - idx_z = tl.program_id(1) // HQ + idx_zq = tl.program_id(1) // HQ idx_hq = tl.program_id(1) % HQ idx_m = offs_m[:, None] idx_d = tl.arange(0, V_HEAD_DIM)[None, :] mask = idx_m < Q_LEN # TODO generalize and add proper mask support - {{store_output(("idx_z", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} # TODO dont want to write this if we dont require grad if OUTPUT_LOGSUMEXP: @@ -429,9 +465,9 @@ def forward_block_mn( if IS_DIVISIBLE: k = tl.load(K_block_ptr) else: - k = tl.load(K_block_ptr, boundary_check=(1,)) + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option = "zero") # -- compute qk --- - qk = tl.dot(q, k) # TODO: use cuda matmul when q_len <= 2. + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. if not PRESCALE_QK: qk *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -439,8 +475,11 @@ def forward_block_mn( # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, # which is larger than the actual number of elements. To avoid access memory out of bound, # we need to mask out the elements that are out of Q_LEN & KV_LEN. - offs_m = offs_m % Q_LEN - offs_n = offs_n % KV_LEN + m = offs_m % Q_LEN + n = offs_n % KV_LEN + else: + m = offs_m + n = offs_n {{ modification( subgraph_number=0, @@ -448,8 +487,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", out="qk" ) | indent_except_first(1) }} @@ -464,8 +503,8 @@ def forward_block_mn( score="qk", b="off_z", h="off_h", - m="offs_m", - n="offs_n", + m="m", + n="n", ) | indent_except_first(2) }} if CHECK_BLOCK_BOUNDARY: @@ -499,8 +538,8 @@ def forward_block_mn( if IS_DIVISIBLE: v = tl.load(V_block_ptr) else: - v = tl.load(V_block_ptr, boundary_check=(0,)) - acc = tl.dot(p.to(MATMUL_PRECISION), v, acc) + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option = "zero") + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) # -- update m_i m_i = m_ij @@ -688,6 +727,7 @@ def flex_attention( mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if _use_flex_decoding(query, kernel_options): return create_flex_decoding_kernel( query, @@ -732,7 +772,9 @@ def flex_attention( Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" + assert V.graph.sizevars.evaluate_expr( + sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) + ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" B = Bq if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: @@ -744,11 +786,18 @@ def flex_attention( # This works because only the last dim differs and we check it is contiguous. q_strides = query.get_stride() assert q_strides[-1] == 1, "Query must be contiguous in the last dimension" + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + stride_order = get_stride_order(query.get_stride()) + fill_order = stride_order2fill_order(stride_order) + out_strides = construct_strides(out_size, fill_order) + layout = FixedLayout( query.get_device(), query.get_dtype(), [B, Hq, seq_len_q, v_head_dim], - query.get_stride(), + stride=out_strides, ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = [B, Hq, seq_len_q] @@ -787,6 +836,16 @@ def flex_attention( (64, 64, 4, 3), ] + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." + # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. @@ -931,10 +990,11 @@ def flex_attention_backward_grid( stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} - Z = {{size("Q", 0)}} + ZQ = {{size("Q", 0)}} HQ = {{size("Q", 1)}} HKV = {{size("K", 1)}} Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} KV_LEN = {{size("K", 2)}} MATMUL_PRECISION = Q.dtype.element_ty @@ -944,17 +1004,20 @@ def flex_attention_backward_grid( NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) off_hz = tl.program_id(2) - off_z = off_hz // HKV # batch idx + off_zq = off_hz // HKV # q batch idx off_hkv = off_hz % HKV # kv head idx + off_zkv = off_zq % ZKV # kv batch idx SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - sparse_idx_z = off_z % SPARSE_Z + sparse_idx_z = off_zq % SPARSE_Z - k_adj = (stride_kh * off_hkv + stride_kz * off_z).to(tl.int64) - v_adj = (stride_vh * off_hkv + stride_vz * off_z).to(tl.int64) - dv_adj = (stride_dvh * off_hkv + stride_dvz * off_z).to(tl.int64) + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) # offset K, V, DV pointers for batch/kv-head K += k_adj @@ -984,10 +1047,10 @@ def flex_attention_backward_grid( sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. - q_adj2 = (stride_qh * off_hq2 + stride_qz * off_z).to(tl.int64) - do_adj2 = (stride_doh * off_hq2 + stride_doz * off_z).to(tl.int64) - dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_z).to(tl.int64) - off_chz2 = ((off_z * HQ + off_hq2) * Q_LEN).to(tl.int64) + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) Q2 = Q + q_adj2 DO2 = DO + do_adj2 @@ -1019,6 +1082,7 @@ def flex_attention_backward_grid( else: Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) lse = lse[:, None] # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1032,7 +1096,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, K, V, dq, q, do, Di, lse, - off_z, off_hq2, offs_m2, offs_n2, + off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, @@ -1051,7 +1115,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, K, V, dq, q, do, Di, lse, - off_z, off_hq2, offs_m2, offs_n2, + off_zq, off_hq2, offs_m2, offs_n2, stride_kn, stride_kd, stride_vn, stride_vd, kv_indices, sparse_kv_num_blocks, MATMUL_PRECISION, @@ -1096,10 +1160,10 @@ def flex_attention_backward_grid( off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g # Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. - q_adj1 = (stride_qh * off_hq1 + stride_qz * off_z).to(tl.int64) - do_adj1 = (stride_doh * off_hq1 + stride_doz * off_z).to(tl.int64) - dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_z).to(tl.int64) - off_chz1 = ((off_z * HQ + off_hq1) * Q_LEN).to(tl.int64) + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) Q1 = Q + q_adj1 DO1 = DO + do_adj1 @@ -1125,7 +1189,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, - off_z, off_hq1, offs_n1, offs_m1, + off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, @@ -1145,7 +1209,7 @@ def flex_attention_backward_grid( {{gen_argdefs()}}, Q1, DO1, DELTA1, LSE1, dk, dv, k, v, - off_z, off_hq1, offs_n1, offs_m1, + off_zq, off_hq1, offs_n1, offs_m1, stride_qm, stride_qd, stride_dom, stride_dod, q_indices, sparse_q_num_blocks, MATMUL_PRECISION, @@ -1165,7 +1229,10 @@ def flex_attention_backward_grid( dk *= SM_SCALE mask = index_n < KV_LEN - {{store_output(("off_z", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} @triton.jit def bwd_dq_inner( @@ -1192,7 +1259,7 @@ def bwd_dq_inner( # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N2, 1)) + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) if not IS_DIVISIBLE: if hi >= 1: for start_n in range(0, hi - 1): @@ -1268,7 +1335,7 @@ def bwd_dq_block_mn( kT = tl.load(kT_ptrs) else: kT = tl.load(kT_ptrs, mask=offs_n2[None, :] < KV_LEN) - qk = tl.dot(q, kT) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) if not PRESCALE_QK: qk *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -1318,7 +1385,7 @@ def bwd_dq_block_mn( vT = tl.load(vT_ptrs) else: vT = tl.load(vT_ptrs, mask=offs_n2[None, :] < KV_LEN) - dp = tl.dot(do, vT) + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) ds = p * (dp - Di[:, None]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( @@ -1344,7 +1411,7 @@ def bwd_dq_block_mn( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ds = ds.to(MATMUL_PRECISION) # Compute dQ. - dq += tl.dot(ds, tl.trans(kT)) + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) return dq @@ -1373,7 +1440,7 @@ def bwd_dkdv_inner( do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(Q_LEN // BLOCK_M1, 1)) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) if not IS_DIVISIBLE: if hi >= 1: @@ -1451,7 +1518,8 @@ def bwd_dkdv_block_mn( else: qT = tl.load(qT_ptrs, mask=offs_m1[None, :] < Q_LEN) lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) - qkT = tl.dot(k, qT) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) if not PRESCALE_QK: qkT *= SM_SCALE # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ @@ -1501,13 +1569,13 @@ def bwd_dkdv_block_mn( do = tl.load(do_ptrs, mask=offs_m1[:, None] < Q_LEN) # Compute dV. ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do) + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) if IS_DIVISIBLE: Di = tl.load(DELTA + offs_m1) else: Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do)) + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) dsT = pT * (dpT - Di[None, :]) # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ {{ modification( @@ -1530,7 +1598,7 @@ def bwd_dkdv_block_mn( # (grads) apply mask for partially unmasked block dsT = tl.where(mask_mod_output, dsT, 0.0) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT)) + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) return dk, dv """ @@ -1603,17 +1671,18 @@ def flex_attention_backward(*args, **kwargs): ] ) - if _use_flex_decoding(query, kernel_options): - raise NotImplementedError("Flex decoding backward pass is not implemented. ") - device = query.get_device() dtype = query.get_dtype() Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" + + assert V.graph.sizevars.evaluate_expr( + sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) + ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" B = Bq kernel_options = dict(kernel_options) + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: kernel_options.setdefault("IS_DIVISIBLE", False) else: @@ -1653,10 +1722,10 @@ def flex_attention_backward(*args, **kwargs): mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph ) - layout_k = FixedLayout( + layout_broadcasted_k = FixedLayout( key.get_device(), key.get_dtype(), - key.get_size(), + [Bq, Hkv, seq_len_kv, qk_head_dim], key.get_stride(), ) @@ -1673,8 +1742,11 @@ def flex_attention_backward(*args, **kwargs): grad_query = empty_strided( query.get_size(), query.get_stride(), dtype=dtype, device=device ) - grad_value = empty_strided( - value.get_size(), value.get_stride(), dtype=dtype, device=device + broadcasted_grad_value = empty_strided( + (Bq, *value.get_size()[1:]), + value.get_stride(), + dtype=dtype, + device=device, ) kernel_options.setdefault("SM_SCALE", scale) @@ -1737,7 +1809,7 @@ def flex_attention_backward(*args, **kwargs): delta, grad_out, grad_query, - grad_value, + broadcasted_grad_value, kv_num_blocks, kv_indices, q_num_blocks, @@ -1747,9 +1819,9 @@ def flex_attention_backward(*args, **kwargs): full_q_num_blocks, full_q_indices, ], - layout=layout_k, # We use store_output only for grad_key + layout=layout_broadcasted_k, # We use store_output only for grad_key subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer, mask_graph_buffer], - mutated_inputs=[grad_query, grad_value], + mutated_inputs=[grad_query, broadcasted_grad_value], call_sizes=query.get_size() + key.get_size()[1:3], num_stages=num_stages, num_warps=num_warps, @@ -1764,7 +1836,7 @@ def flex_attention_backward(*args, **kwargs): delta, grad_out, grad_query, - grad_value, + broadcasted_grad_value, kv_num_blocks, kv_indices, q_num_blocks, @@ -1788,13 +1860,24 @@ def flex_attention_backward(*args, **kwargs): 15: create_indices_fake, } - grad_key = autotune_select_algorithm( + broadcasted_grad_key = autotune_select_algorithm( "flex_attention_backward", choices, inputs_for_autotuning, - layout_k, + layout_broadcasted_k, input_gen_fns=input_gen_fns, - ) + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + if Bq == Bkv: + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert ( + Bq > 1 and Bkv == 1 + ), f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}" + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + return ( grad_query, grad_key, diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 3520b618ef0a67..c758a3bcfbc961 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -83,6 +83,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} HKV = {{size("Q", 1)}} G: tl.constexpr = GQA_SHARED_HEADS HQ = HKV * G @@ -97,12 +98,13 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV off_hkv = tl.program_id(0) % HKV off_t = tl.program_id(1) q_offset = off_z * stride_qz + off_hkv * stride_qh - k_offset = off_z * stride_kz + off_hkv * stride_kh - v_offset = off_z * stride_vz + off_hkv * stride_vh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} @@ -161,7 +163,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me # first kv block we're loading # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( base=K + k_offset, @@ -207,15 +209,15 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(KV_LEN // BLOCK_N, 1)) + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM, BLOCK_N), - order=(0, 1) + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM, BLOCK_N), + order=(0, 1) ) V_block_ptr = tl.make_block_ptr( base=V + v_offset, @@ -277,7 +279,7 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me idx_hq = off_hkv*G + off_g[:, None, None] idx_m = off_m[None, :, None] idx_d = offs_vd[None, None, :] - # TODO generalize and add proper mask support + mask = (idx_m < Q_LEN) acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} @@ -338,7 +340,10 @@ def create_flex_decoding_kernel(*args, **kwargs): Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() - assert Bq == Bkv, "Batch dimension must match" + + if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): + raise RuntimeError(f"Bq and Bkv must broadcast. Got Bq={Bq} and Bkv={Bkv}") + B = Bq kernel_options = dict(kernel_options) @@ -469,6 +474,8 @@ def create_flex_decoding_kernel(*args, **kwargs): ) # TODO: This feels sketchy kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function @@ -546,8 +553,8 @@ def create_flex_decoding_kernel(*args, **kwargs): # See [Note] Handle fully masked out rows: # g_M Is the global max among split kv blocks. masked_rows = lowerings[aten.eq](g_M, -float("inf")) - g_M = lowerings[aten.where](masked_rows, 0.0, g_M) adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) alpha = lowerings[aten.exp2](adj_M) buf_L = lowerings[aten.mul](buf_L, alpha) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 0c66da4ca4f314..21ba6c1e215dbe 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -2,7 +2,7 @@ import functools import itertools import logging -from typing import cast, List, Tuple +from typing import cast, Sequence, Tuple import sympy @@ -28,7 +28,7 @@ def filtered_configs( m: int, n: int, k: int, - configs: List[Tuple[int, int, int, int, int]], + configs: Sequence[Tuple[int, int, int, int, int]], has_int8_tensor=False, ): """Heuristic to shrink configs when they are bigger than the input size""" diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index 2372420c1d58b6..2f0d020716a199 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -192,6 +192,18 @@ aten__fp8_mm = ExternKernelChoice(torch._scaled_mm, "at::_scaled_mm") +def are_compatible_scales(size_a: List[int], size_b: List[int]) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + def scaled_mm_options( # type: ignore[no-untyped-def] config, # triton.Config sym_m: sympy.core.numbers.Integer, @@ -207,10 +219,11 @@ def scaled_mm_options( # type: ignore[no-untyped-def] sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] ) - assert len(scale_a.get_size()) == len( - scale_b.get_size() - ), "Expect inverse scale_a and scale_b to be both scalars (tensor-wise scaling) or tensors (rowwise scaling)." - + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) return dict( GROUP_M=8, EVEN_K=even_k_symbolic, @@ -220,7 +233,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def] num_stages=config.num_stages, num_warps=config.num_warps, # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) != 0, + SCALING_ROWWISE=len(scale_a.get_size()) == 2, **config.kwargs, ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py new file mode 100644 index 00000000000000..70a7d2fdd7967b --- /dev/null +++ b/torch/_inductor/loop_body.py @@ -0,0 +1,604 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import itertools +import re +from enum import auto, Enum +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.fx.proxy import Scope, TracerBase +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs +from .virtualized import ops, V + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.lru_cache(None) + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +# We don't need the nn.Module and constant handling in Tracer +class LightTracer(TracerBase): + def __init__(self): + super().__init__() + self.graph = torch.fx.Graph(tracer_cls=self.__class__) # type: ignore[arg-type] + self.scope = Scope("", None) + self.module_stack = {} # type: ignore[assignment] + self.node_name_to_scope = {} + + +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. + """ + + indexing_exprs: Dict[str, sympy.Expr] + indexing_exprs_name: Dict[sympy.Expr, str] + submodules: Dict[str, Any] + subblocks: Dict[str, LoopBodyBlock] + indirect_vars: List[str] + indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] + + def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def _init_with_tracing(self, fn, args): + """Do an FX trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} + self.root_block = LoopBodyBlock(self, fn, args) # traces + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage + self.root_block = other.root_block.clone(self) + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. + """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + # if iter_sizes == old_iter_sizes: + # # no dimensions get merged. + # return old_sizes, old_body + + # Note: if no dimension get merges, the symbol prefix will + # remain 'y'. But if we merge dimensions, we change prefix to + # 'z'. If this is an issue, we can always retrace the LoopBody + # to change symbol prefix to 'z'. + # + # There is indeed an issue due to symbol name conflicting. + # y0 maybe reused for the y dimension later. + ( + iter_vars, + reduce_vars, + ), var_ranges = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="t" + ) + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + + # use the original symbol prefix + # Can try to optimize if this is a bottleneck for compilation time + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + iter_sizes, reduce_sizes, prefix="z" + ) + new_body2 = LoopBody( + new_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body2 + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. + """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="t" # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = list(itertools.chain(*indices)) + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, prefix="z" # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. + """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + + __repr__ = debug_str + + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" + replacements = dict(zip(self.var_ranges.keys(), index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + def __call__(self, *indices): + self.indexing = self.indexing_from_args(indices) + result = self.root_block() + self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, hower in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. + """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + self.body = body + + def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): + return tracer.create_proxy( + "call_module", + "get_index", + (body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] + self.name = "CaptureIndexing" + + def load(self, name: str, index: sympy.Expr): + index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = self._inner.reduction(dtype, src_dtype, reduction_type, value) + if "welford" in reduction_type: + return tuple(result[i] for i in range(3)) + return result + + def index_expr(self, index, dtype): + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + offsets_size = add_index( + offsets_size, MemoryUsageType.BUCKETIZE, buffer_name=offsets_name + ) + return self._inner.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + + @staticmethod + def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = LoopBodyBlock(self.body, masked_body, []) + return tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + @staticmethod + def scan( + dtype_proxy, + combine_fn: Callable[ + [Tuple[Any, ...], Tuple[Any, ...]], Tuple[Any, ...] + ], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + @staticmethod + def indirect_indexing(index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. + """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim( + var, size, check, wrap_neg + ) + tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + @staticmethod + def output(result): + tracer.create_proxy("output", "output", (result,), {}) + + tracer = LightTracer() + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + from .sizevars import SimplifyIndexing + + handler: Any = SimplifyIndexing( + CaptureIndexing(proxy_ops), self.body.var_ranges + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 86872e46a0b4c0..adc199a20c1bcf 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -18,10 +18,7 @@ import torch.fx import torch.utils._pytree as pytree from torch._higher_order_ops.associative_scan import associative_scan_op -from torch._higher_order_ops.triton_kernel_wrap import ( - triton_kernel_wrapper_functional, - triton_kernel_wrapper_mutation, -) +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation from torch._prims_common import ( canonicalize_dim, canonicalize_dims, @@ -104,11 +101,32 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A _maybe_layout_constraints[fn] = None return None # We lazily register tag-based layout constraints. - if torch._C.Tag.needs_fixed_stride_order in fn.tags: - _maybe_layout_constraints[fn] = constrain_to_fx_strides - return _maybe_layout_constraints[fn] - _maybe_layout_constraints[fn] = None - return None + + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) + + +def get_layout_constraint_tag(fn): + tags_by_priority = [ + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, + ] + for tag in tags_by_priority: + if tag in fn.tags: + return tag + if torch._library.utils.is_builtin(fn): + return torch._C.Tag.flexible_layout + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) def assert_nyi(cond, msg): @@ -536,7 +554,9 @@ def inner(*inputs: List[List[TensorBox]], alpha=1): def group_args(arg_pairs): out = defaultdict(list) for i, args in enumerate(arg_pairs): - use_foreach = not is_dynamic(*args) + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) device = None for t in args: if isinstance(t, TensorBox): @@ -1463,7 +1483,7 @@ def op_count(x): if not isinstance(x, ir.Pointwise): return 0 - count = x.inner_fn_opcount() + count = x.inner_fn_opcount().num_ops for read in x.get_read_names(): count += op_count(V.graph.get_buffer(read)) @@ -2094,6 +2114,8 @@ def apply_constraint(arg, fx_arg): if isinstance(arg, ir.IRNode): stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) return ir.ExternKernel.require_stride_order(arg, stride_order) + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()} return arg args = tuple( @@ -2213,7 +2235,6 @@ def is_aligned(x): # Need templated kernel make_fallback(aten.addbmm) -make_fallback(aten.addmv, warn=False) make_fallback(aten._addmm_activation, warn=False) # Need templated kernel. Probably impossible to write efficiently @@ -6081,13 +6102,17 @@ def set__source_tensor(self, source_tensor): return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) -if hasattr(torch.ops.fsdp, "set_"): +if hasattr(torch.ops.fsdp, "copy_"): - @register_lowering(torch.ops.fsdp.set_.default) - def fsdp_set_(self, source_tensor): - self.realize() - source_tensor.realize() - ir.SetSourceTensorKernel(self, source_tensor) + @register_lowering(torch.ops.fsdp.copy_.default) + def fsdp_copy_(dst, src): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) @register_lowering(torch.ops.aten.resize) @@ -6174,39 +6199,6 @@ def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs): return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} -@register_lowering(triton_kernel_wrapper_functional) -def triton_kernel_wrap( - *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): - new_kwargs = {} - for name, value in kwargs.items(): - if isinstance(value, ir.TensorBox): - x = value.data - has_non_rv_views = False - while isinstance(x, ir.BaseView): - if not isinstance(x, ir.ReinterpretView): - has_non_rv_views = True - break - x = x.data - if has_non_rv_views: - # we realize the inputs wrapped into any view which is not - # ReinterpretView to convert them into ReinterpretView during - # realization; all views being ReinterpretView is assumed by - # the downstream code (e.g., preserving ReinterpretView in - # cloning; layout should be available in mutation marking) - value = ir.TensorBox(ir.ExternKernel.realize_input(value)) - if name in tensors_to_clone: - value = clone_preserve_reinterpret_view(value) - new_kwargs[name] = value - - return triton_kernel_wrap_( - kernel_idx=kernel_idx, - constant_args_idx=constant_args_idx, - grid=grid, - kwargs=new_kwargs, - ) - - @register_lowering(torch.ops.higher_order.cond) def cond(pred, true_fn, false_fn, operands): if is_triton(pred) or any(map(is_triton, operands)): @@ -6232,12 +6224,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): @register_lowering(associative_scan_op, type_promotion_kind=None) -def associative_scan(combine_fn: ir.Subgraph, input, dim: int): +def associative_scan(combine_fn: ir.Subgraph, xs, dim: int): from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph subgraph_inputs = [ InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) - for x in itertools.chain(input, input) + for x in itertools.chain(xs, xs) ] lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] @@ -6247,9 +6239,9 @@ def wrapped_combine_fn(lhs, rhs): *pytree.tree_leaves(rhs), ) - kwargs = _make_scan_inner(input[0], axis=dim, dtype=None) - kwargs["dtypes"] = tuple(x.get_dtype() for x in input) - kwargs["inner_fns"] = tuple(x.make_loader() for x in input) + kwargs = _make_scan_inner(xs[0], axis=dim, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in xs) + kwargs["inner_fns"] = tuple(x.make_loader() for x in xs) result = ir.Scan.create( combine_fn=wrapped_combine_fn, can_fallback_to_aten=False, @@ -6265,7 +6257,7 @@ def _sink_tokens(tokens): return None -@register_lowering(torch.ops.higher_order.with_effects) +@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) def with_effects(token, op, *args, **kwargs): result = ir.EffectfulKernel.create(op, *args, **kwargs) diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 5c26e322f12633..fe77279800e3da 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -49,6 +49,8 @@ class CppOuterLoopFusedCount: num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 +num_loop_reordering = 0 + # reset all counters def reset(): @@ -60,6 +62,7 @@ def reset(): global cpp_outer_loop_fused_inner_counts global num_comprehensive_padding global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 @@ -71,6 +74,7 @@ def reset(): cpp_outer_loop_fused_inner_counts.clear() num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 @dataclass diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 0a8bbb9d1e6e17..df9e475be2796d 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -261,6 +261,8 @@ def codegen(self, wrapper): self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=[*self.inputs, *self.constant_args], ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -335,6 +337,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -428,6 +432,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + [*self.inputs, *self.constant_args], ) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: @@ -608,6 +614,7 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_arg_names = [ "x_scale", @@ -623,15 +630,27 @@ def codegen(self, wrapper): "scalars", "algorithm", ] + if not self.has_bias: + const_arg_names.insert(2, "bias") const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[2] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[2] w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] ( x_scale, x_zp, + ) = const_args[:2] + ( + x_scale_raw, + x_zp_raw, + ) = self.constant_args[:2] + ( stride, padding, dilation, @@ -642,8 +661,19 @@ def codegen(self, wrapper): unary_attr, unary_scalars, unary_algorithm, - ) = const_args[-12:] - + ) = const_args[-10:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] codegen_args = ( x, x_scale, @@ -663,6 +693,25 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -670,6 +719,8 @@ def codegen(self, wrapper): codegen_args, self.cpp_op_schema, self.cpp_kernel_key, + op_overload=self.op_overload, + raw_args=raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -802,6 +853,7 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_arg_names = [ "x_scale", @@ -821,17 +873,35 @@ def codegen(self, wrapper): "unary_scalars", "unary_algorithm", ] + if not self.has_bias: + const_arg_names.insert(4, "bias") const_args = list(self.codegen_const_args(const_arg_names)) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] - bias = args[2] if self.has_bias else const_args[0] + packed_weight_raw = self.inputs[1] + bias = args[2] if self.has_bias else const_args[4] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[4] accum, w_scale, w_zp = args[-3], args[-2], args[-1] + accum_raw, w_scale_raw, w_zp_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) ( x_scale, x_zp, accum_scale, accum_zp, + ) = const_args[:4] + ( + x_scale_raw, + x_zp_raw, + accum_scale_raw, + accum_zp_raw, + ) = self.constant_args[:4] + ( stride, padding, dilation, @@ -844,7 +914,21 @@ def codegen(self, wrapper): unary_attr, unary_scalars, unary_algorithm, - ) = const_args[-16:] + ) = const_args[-12:] + ( + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] conv_args = ( x, x_scale, @@ -869,6 +953,30 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + accum_raw, + accum_scale_raw, + accum_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + stride_raw, + padding_raw, + dilation_raw, + groups_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -877,6 +985,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + op_overload=self.op_overload, + raw_args=raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1219,17 +1329,23 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] w_scale, w_zp = args[-2], args[-1] + w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] if self.x_scale_zp_are_tensors: assert len(args) >= 4 x_scale, x_zp = args[-4], args[-3] + x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3] ( o_scale, o_zp, @@ -1238,6 +1354,14 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-6:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-6:] else: assert len(const_args) >= 8 ( @@ -1250,6 +1374,16 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-8:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-8:] codegen_args = ( x, @@ -1266,6 +1400,21 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -1274,6 +1423,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1396,17 +1547,27 @@ def __init__( def codegen(self, wrapper): # Parser the inputs and constant + # The raw_args setup can be skipped if there is a C shim implementation args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] + x_raw = self.inputs[0] packed_weight = args[1] + packed_weight_raw = self.inputs[1] bias = args[2] if self.has_bias else const_args[0] + bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] w_scale, w_zp, other = args[-3], args[-2], args[-1] + w_scale_raw, w_zp_raw, other_raw = ( + self.inputs[-3], + self.inputs[-2], + self.inputs[-1], + ) if self.x_scale_zp_are_tensors: assert len(args) >= 5 x_scale, x_zp = args[-5], args[-4] + x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4] ( o_scale, o_zp, @@ -1419,6 +1580,18 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-10:] + ( + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-10:] else: assert len(const_args) >= 8 ( @@ -1435,6 +1608,20 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) = const_args[-12:] + ( + x_scale_raw, + x_zp_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) = self.constant_args[-12:] codegen_args = ( x, @@ -1456,6 +1643,26 @@ def codegen(self, wrapper): unary_scalars, unary_algorithm, ) + raw_args = ( + x_raw, + x_scale_raw, + x_zp_raw, + packed_weight_raw, + w_scale_raw, + w_zp_raw, + other_raw, + bias_raw, + o_scale_raw, + o_zp_raw, + output_dtype_raw, + other_scale_raw, + other_zp_raw, + binary_attr_raw, + alpha_raw, + unary_attr_raw, + unary_scalars_raw, + unary_algorithm_raw, + ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.python_kernel_name, @@ -1464,6 +1671,8 @@ def codegen(self, wrapper): self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, + self.op_overload, + raw_args, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1578,6 +1787,7 @@ def __init__( None, op_overload=torch.ops.aten.mkldnn_rnn_layer.default, ) + self.outputs: List[MultiOutput] = [] @classmethod def create( @@ -1647,11 +1857,14 @@ def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" return FlexibleLayout.contiguous_strides(output_shape) - output_sizes = [output_shape, hy_shape, cy_shape] + # C shim call requires all the outputs to be passed in, and thus the last + # dummy return value is added. + output_sizes = [output_shape, hy_shape, cy_shape, [1]] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), FlexibleLayout.contiguous_strides(hy_shape), FlexibleLayout.contiguous_strides(cy_shape), + [1], ] output_ir = [ MultiOutput( @@ -1668,5 +1881,10 @@ def get_strides_of_lstm_output(output_shape, batch_first): zip(output_sizes, output_strides) ) ] + packed.outputs = output_ir return output_ir + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + return super().codegen(wrapper) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index bf8f0d3d6370ca..c47ee1026ab919 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -5,7 +5,9 @@ Callable, Dict, Generic, + List, Literal, + NamedTuple, Optional, Tuple, TypeVar, @@ -19,6 +21,7 @@ import torch import torch.utils._pytree as pytree +from ..utils._ordered_set import OrderedSet from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str @@ -955,6 +958,13 @@ def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: return h +class OpCountResult(NamedTuple): + num_ops: int + used_ops: OrderedSet[str] + read_buffers: List[str] + nontrivial_read_count: int + + class OpCounterCSE: """Shim to count how many ops are used""" @@ -963,25 +973,67 @@ def __init__(self, inner): self.parent_handler = inner self.op_count = 0 self.var_names = {} + self._used_ops: OrderedSet[str] = OrderedSet() + self._read_names: List[str] = [] + self._nontrivial_read_count = 0 def __getattr__(self, name): def inner(*args, **kwargs): - val = getattr(self.parent_handler, name)(*args, **kwargs) - if name == "indirect_indexing": - return val + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) - def count(val): - if val not in self.var_names: - varname = f"tmp{self.op_count}" - self.op_count += 1 - self.var_names[val] = varname - return varname - else: - return self.var_names[val] + self._used_ops.add(name) + return inner - return pytree.tree_map(count, val) + def _update_count(self, val): + varname = self.var_names.get(val) + if not varname: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + + def indirect_indexing(self, *args, **kwargs): + self._used_ops.add("indirect_indexing") + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def load(self, name: str, index: sympy.Expr) -> str: + val = self.parent_handler.load(name, index) + if val not in self.var_names: + self._used_ops.add("load") + self._read_names.append(name) + if not isinstance(index, (sympy.Integer, int)): + self._nontrivial_read_count += 1 + return self._update_count(val) - return inner + def load_seed(self, name: str, offset: T): + val = self.parent_handler.load_seed(name, offset) + if val not in self.var_names: + self._used_ops.add("load_seed") + self._read_names.append(name) + return self._update_count(val) + + def bucketize( + self, + values, + offsets_name: str, + offsets_size: sympy.Expr, + indexing_dtype: torch.dtype, + right: bool, + ): + val = self.parent_handler.bucketize( + values, offsets_name, offsets_size, indexing_dtype, right + ) + if val not in self.var_names: + self._used_ops.add("bucketize") + self._read_names.append(offsets_name) + return self._update_count(val) + + def getvalue(self): + return OpCountResult( + self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count + ) def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index c333bd882942e5..96bf8641f3c9a6 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -6,7 +6,7 @@ import torch from torch.utils._sympy.value_ranges import ValueRanges -from .ir import LoopBody +from .loop_body import LoopBody from .utils import dominated_nodes diff --git a/torch/_inductor/package/__init__.py b/torch/_inductor/package/__init__.py index c088562100cdae..15587401b72358 100644 --- a/torch/_inductor/package/__init__.py +++ b/torch/_inductor/package/__init__.py @@ -1 +1 @@ -from .package import load_package, package_aoti +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index d1304293e3cd88..ca62b7172e664d 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -1,24 +1,25 @@ -import glob import json +import logging import os import shlex import subprocess -import tempfile import zipfile from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Dict, List, Optional, Union import torch import torch._inductor import torch.utils._pytree as pytree -from torch._inductor import config, exc +from torch._inductor import exc from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder from torch.export._tree_utils import reorder_kwargs -from .build_package import build_package_contents from .pt2_archive_constants import AOTINDUCTOR_DIR, ARCHIVE_VERSION +log = logging.getLogger(__name__) + + class PT2ArchiveWriter: def __init__(self, archive_path: str) -> None: self.archive_path: str = archive_path @@ -154,84 +155,71 @@ def get_aoti_file_with_suffix(suffix: str) -> str: return output_so -def package_aoti(aoti_output_dir: str) -> str: +def package_aoti(archive_file: str, aoti_files: Union[str, Dict[str, str]]) -> str: """ Saves the AOTInductor generated files to the PT2Archive format. - """ - - # Add a makefile and python script - build_package_filename = "build_package.py" - with open(os.path.join(aoti_output_dir, build_package_filename), "w") as f: - f.write(build_package_contents) - - with open(os.path.join(aoti_output_dir, "Makefile"), "w") as f: - f.write(f"all:\n\tpython3 {build_package_filename}\n") - - if config.aot_inductor.output_path.endswith(".so"): - raise RuntimeError( - "Unable to save package as a .so. It should be a .pt2 format or a directory." - ) - elif config.aot_inductor.output_path.endswith(".pt2"): - # Save using the PT2 packaging format - # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) - archive_path = config.aot_inductor.output_path - - with PT2ArchiveWriter(archive_path) as archive_writer: - package_files = glob.glob(f"{aoti_output_dir}/*") - - for path in package_files: - filename = os.path.basename(path) - archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path) - return archive_path - - else: - # Directly put the files into the directory, without any archiving - return aoti_output_dir + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + if isinstance(aoti_files, str): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + assert archive_file.endswith(".pt2") + + # Save using the PT2 packaging format + # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) + + with PT2ArchiveWriter(archive_file) as archive_writer: + for model_name, aoti_output_dir in aoti_files.items(): + log.debug( + "Packaging AOTInductor files from %s with model name, %s", + aoti_output_dir, + model_name, + ) + for root, dirs, files in os.walk(aoti_output_dir): + for file in files: + log.debug( + "Saving AOTI generated file %s to archive in %s%s/%s", + os.path.join(root, file), + AOTINDUCTOR_DIR, + model_name, + file, + ) + archive_writer.write_file( + f"{AOTINDUCTOR_DIR}{model_name}/{file}", + os.path.join(root, file), + ) + return archive_file + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader -def load_package(path: str, device: str) -> Callable: # type: ignore[type-arg] - if path.endswith(".so"): - raise RuntimeError( - "Unable to load .so. It should be a .pt2 format or a directory." - ) - - elif path.endswith(".pt2"): - so_path = os.path.splitext(path)[0] - with PT2ArchiveReader(path) as archive_reader: - file_names = archive_reader.get_file_names() - - with tempfile.TemporaryDirectory() as tmp_dir: - archive_reader.extractall(tmp_dir) - file_names = archive_reader.get_file_names() - aoti_files = [ - file for file in file_names if file.startswith(AOTINDUCTOR_DIR) - ] - - so_path = compile_so(tmp_dir, aoti_files, so_path) - - else: - assert os.path.isdir(path), "Must specify a directory or a .pt2 file" - aoti_files = [ - os.path.join(root, file) - for root, dirs, files in os.walk(path) - for file in files - ] - so_path = compile_so(path, aoti_files, path) - - if device == "cpu": - runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg] - elif device == "cuda" or device.startswith("cuda:"): - runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg] - else: - raise RuntimeError("Unsupported device " + device) - - def optimized(*args, **kwargs): # type: ignore[no-untyped-def] - call_spec = runner.get_call_spec() # type: ignore[attr-defined] + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() # type: ignore[attr-defined] in_spec = pytree.treespec_loads(call_spec[0]) out_spec = pytree.treespec_loads(call_spec[1]) flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] - flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + flat_outputs = self.loader.run(flat_inputs) # type: ignore[attr-defined] return pytree.tree_unflatten(flat_outputs, out_spec) - return optimized + def get_metadata(self) -> Dict[str, str]: + return self.loader.get_metadata() # type: ignore[attr-defined] + + +def load_package(path: str, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg] + if not path.endswith(".pt2"): + raise RuntimeError("Unable to load package. Path must be a .pt2 file.") + + loader = torch._C._aoti.AOTIModelPackageLoader(path, model_name) # type: ignore[call-arg] + return AOTICompiledModel(loader) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 9c78d09ae9ab0b..e43d37fd37b1a7 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -200,7 +200,8 @@ def bundle(self) -> Match: def __repr__(self) -> str: return f"Match(..., {self.args}, {self.kwargs})" - def erase_nodes(self, graph: torch.fx.Graph) -> None: + def erase_nodes(self) -> None: + graph = self.graph for n in reversed(self.nodes): if not n._erased and not n.users: graph.erase_node(n) @@ -236,9 +237,13 @@ def replace_by_example( replacement graph. """ - from torch._inductor.virtualized import V + from torch._inductor.virtualized import NullHandler, V - context = V.fake_mode if V.fake_mode is not None else contextlib.nullcontext + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) with context: if trace_fn is None: @@ -1005,7 +1010,7 @@ def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> Non replacement.meta.update(node.meta) node.replace_all_uses_with(replacement) assert match.nodes[-1] is node - match.erase_nodes(graph) + match.erase_nodes() @dataclasses.dataclass @@ -1032,9 +1037,6 @@ def replace_with_graph( replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], args: Sequence[torch.fx.Node], ) -> None: - output_nodes = match.output_nodes() - first_node = output_nodes[0] - class Replacer(torch.fx.Interpreter): call_method = None # type: ignore[assignment] call_module = None # type: ignore[assignment] @@ -1176,7 +1178,7 @@ def replace( assert len(output_nodes) == 1 replace(output_nodes[0], replacement) - match.erase_nodes(graph) + match.erase_nodes() def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: assert match.replacement_graph is not None diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 391192eaac6832..056e24f1b4e26c 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -1,32 +1,140 @@ +from __future__ import annotations + +import json import os +import typing from abc import abstractmethod -from typing import Optional +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing_extensions import override, TypeAlias + +from torch._inductor import config + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + Sample: TypeAlias = Sample_ +else: + Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef] -class RemoteCacheBackend: + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class RemoteCacheBackend(Generic[_T]): """ - A backend implementation for accessing a remote/distributed cache. + A backend implementation for accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. """ - def __init__(self, cache_id: str) -> None: + @abstractmethod + def get(self, key: str) -> Optional[_T]: pass @abstractmethod - def get(self, key: str) -> Optional[object]: + def put(self, key: str, data: _T) -> None: pass + +# Serde that encodes from _T to _U and decodes from _U to _T. +class RemoteCacheSerde(Generic[_T, _U]): @abstractmethod - def put(self, key: str, data: bytes) -> None: + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing. + if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + self.serde = serde + + def get(self, key: str) -> Optional[_T]: + sample = self._create_sample() + result = self._get(key, sample) + self._log_sample(sample) + return result + + def put(self, key: str, value: _T) -> None: + sample = self._create_sample() + self._put(key, value, sample) + self._log_sample(sample) + + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] + return self.serde.decode(data) # type: ignore[arg-type] + + def _encode(self, value: _T, sample: Optional[Sample]) -> Any: # returns _U + return self.serde.encode(value) + + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self.backend.get(key): + return self._decode(data, sample) + return None + + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self.backend.put(key, data) + + def _create_sample(self) -> Optional[Sample]: + return None + + def _log_sample(self, sample: Optional[Sample]) -> None: pass -class RedisRemoteCacheBackend(RemoteCacheBackend): +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): """ A Redis implementation of a remote/distributed cache. """ + _key_fmt: str + _redis: Optional[redis.Redis] = None + def __init__(self, cache_id: str) -> None: - import redis + if not redis: + # We had trouble importing redis - just skip init. + return self._key_fmt = f"pt2:{cache_id}:{{key}}" self._redis = redis.Redis( @@ -34,14 +142,57 @@ def __init__(self, cache_id: str) -> None: port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), ) - def _get_key(self, key: str) -> str: + def __get_key(self, key: str) -> str: return self._key_fmt.format(key=key) + @override def get(self, key: str) -> Optional[bytes]: - value = self._redis.get(self._get_key(key)) + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return None + + try: + value = self._redis.get(self.__get_key(key)) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + # In theory redis.get() can return an Awaitable as well... assert value is None or isinstance(value, bytes) return value + @override def put(self, key: str, data: bytes) -> None: - self._redis.set(self._get_key(key), data) + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + self._redis.set(self.__get_key(key), data) + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, key: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... + backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(key) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 00000000000000..65dfc73d63d72d --- /dev/null +++ b/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +from typing import Dict, List, Optional, Tuple +from typing_extensions import override + +import torch +from torch.utils._triton import has_triton_package + +from ..remote_cache import ( + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) + + +if has_triton_package(): + from triton import Config + +log = logging.getLogger(__name__) + + +_InductorMetaTy = Dict[str, object] + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + filename: str + local_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + remote_cache: Optional[Tuple[RemoteCache[JsonDataTy], str]] = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. + @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> Optional[AutotuneCache]: + cache = AutotuneCache(configs_hash, filename) + cache._setup_local_cache(inductor_meta, filename) + cache._setup_remote_autotune_cache(inductor_meta, filename) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + # Read the best config options from the most local cache and return it. + def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: List[Config] + ) -> Optional[Config]: + if best := self._read(inductor_meta): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache(self, inductor_meta: _InductorMetaTy, filename: str) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + cache_filename = os.path.splitext(filename)[0] + ".best_config" + local_cache = RemoteCache(_LocalAutotuneCacheBackend(), RemoteCacheJsonSerde()) + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, filename: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + remote_cache = _create_cache( + inductor_meta, + self.configs_hash, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + "autotune-best-config-v2", + ) + if not remote_cache: + return + + # we already sha256 hash the source contents + remote_cache_key = os.path.basename(filename) + self.remote_cache = (remote_cache, remote_cache_key) + + # Save the config in the caches + def save( + self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False + ) -> None: + data = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + } + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +def _should_use_remote_autotune_cache(inductor_meta: Dict[str, object]) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def _load_cached_autotuning( + best_config: Dict[str, JsonDataTy], + configs_hash: str, + configs: List[Config], + inductor_meta: Dict[str, object], +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +def _create_cache( + inductor_meta: Dict[str, object], + configs_hash: str, + fb_cache_cls: str, + oss_cache_cls: str, + salt: str, +) -> Optional[RemoteCache[JsonDataTy]]: + backend_hash = inductor_meta.get("backend_hash", None) + if backend_hash is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return None + + assert isinstance(backend_hash, str) + + key = backend_hash + configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + try: + if inductor_meta.get("is_fbcode"): + import torch._inductor.fb.remote_cache + + cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls) + return cache_cls(key) + else: + import torch._inductor.remote_cache + + cache_cls = getattr(torch._inductor.remote_cache, oss_cache_cls) + return cache_cls(key) + except Exception: + log.warning("Unable to create a remote cache", exc_info=True) + return None + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def get(self, key: str) -> Optional[bytes]: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def put(self, key: str, data: bytes) -> None: + with open(key, "wb") as fd: + fd.write(data) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 90f9e5a1cc5966..0f1495e49972c7 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -8,7 +8,7 @@ # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { - "X": 2048, + "X": 4096, "Y": 1024, "Z": 1024, "R": 4096 * 16, # * 16 is multi-kernel only @@ -104,24 +104,32 @@ class DeviceProperties(typing.NamedTuple): regs_per_multiprocessor: Optional[int] = None max_threads_per_multi_processor: Optional[int] = None multi_processor_count: Optional[int] = None + warp_size: Optional[int] = None @classmethod def create(cls, device): import torch from torch._dynamo.device_interface import get_interface_for_device - device_type = device.type if torch.version.hip is None else "hip" + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + device_interface = get_interface_for_device(device) - if device_type == "cuda": + if device_type in ["cuda", "hip"]: props = device_interface.get_device_properties(device) return cls( type=device_type, index=device.index, cc=device_interface.get_compute_capability(device), major=props.major, - regs_per_multiprocessor=props.regs_per_multiprocessor, + regs_per_multiprocessor=props.regs_per_multiprocessor + if hasattr(props, "regs_per_multiprocessor") + else None, max_threads_per_multi_processor=props.max_threads_per_multi_processor, multi_processor_count=props.multi_processor_count, + warp_size=props.warp_size, ) return cls( type=device_type, diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 40e2678e089b54..446dbc71c61d1d 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -65,6 +65,17 @@ def triton_config_to_hashable(cfg): return tuple(items) +def validate_triton_config(cfg): + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. + assert ( + getattr(cfg, "pre_hook", None) is None + ), "triton configs with pre_hooks not supported" + + def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True): info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" slow = ms > 0.012 and gb_per_s < 650 diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 824ecf241d5dd4..abb932266c5a89 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs +from __future__ import annotations + import builtins import copy import functools import hashlib import inspect -import json import logging import math import operator @@ -14,10 +15,11 @@ import sys import threading import time -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Set, Tuple import torch +from .autotune_cache import AutotuneCache from .benchmarking import benchmarker from .coordinate_descent_tuner import CoordescTuner from .hints import ( @@ -40,12 +42,10 @@ get_num_bytes, next_power_of_2, triton_config_to_hashable, + validate_triton_config, ) -if TYPE_CHECKING: - from ..remote_cache import RemoteCacheBackend - try: import triton except ImportError: @@ -172,7 +172,7 @@ def __init__( triton_meta, # passed directly to triton configs, save_cache_hook, - mutated_arg_names, + mutated_arg_names: List[str], # see [Note: clone mutated buffers] heuristic_type, size_hints=None, inductor_meta=None, # metadata not relevant to triton @@ -182,6 +182,10 @@ def __init__( super().__init__() assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + self.fn = fn self.device_props: DeviceProperties = triton_meta["device"] self.triton_meta = { @@ -261,10 +265,11 @@ def precompile(self, warm_cache_only=False): self.inductor_meta.get("dynamic_scale_rblock", True) and self.heuristic_type == HeuristicType.REDUCTION and self.size_hints is not None - # Disable for AMDGPU/Intel as Triton is not ready to return n_regs for a compiled_binary. - and device_prop.type == "cuda" + # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary. + and device_prop.type in ["cuda", "hip"] and device_prop.major - and device_prop.major >= 8 + and (device_prop.major >= 8 or torch.version.hip) + and device_prop.regs_per_multiprocessor is not None ): assert device_prop.regs_per_multiprocessor assert device_prop.max_threads_per_multi_processor @@ -301,7 +306,7 @@ def precompile(self, warm_cache_only=False): ): continue - nreg_per_warp = nreg * 32 + nreg_per_warp = nreg * device_prop.warp_size nreg_per_block = nreg_per_warp * triton_config.num_warps # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' @@ -404,7 +409,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): "num_stages": compile_meta["num_stages"], "debug": compile_meta["debug"], } - if self.device_props.type != "hip": + if self.device_props.type == "hip": if "waves_per_eu" in compile_meta: options["waves_per_eu"] = compile_meta["waves_per_eu"] if "matrix_instr_nonkdim" in compile_meta: @@ -655,11 +660,6 @@ def bench(self, launcher, *args, grid, with_profiler=False, **kwargs): stream = device_interface.get_raw_stream(device_interface.current_device()) def kernel_call(): - if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**dict(zip(self.arg_names, args)), **launcher.config.kwargs} - ) - cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs) launcher( *cloned_args, @@ -673,11 +673,12 @@ def kernel_call(): return do_bench_using_profiling(kernel_call, warmup=10, rep=40) - return benchmarker.benchmark_gpu(kernel_call, rep=40, fast_flush=True) + return benchmarker.benchmark_gpu(kernel_call, rep=40) def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]: from ..compile_fx import clone_preserve_strides + # [Note: clone mutated buffers] # clone inplace buffers to avoid autotune contaminating them if # the kernel does in-place stores. avoid cloning other buffers because # it leads to increase memory use @@ -826,7 +827,7 @@ def benchmark_one_config(config): ) return config2launcher.get(best_config) - def run(self, *args, grid, stream, **kwargs): + def run(self, *args, grid, stream, **kwargs): # type:ignore[override] if len(self.launchers) != 1: if len(self.launchers) == 0: start_time = time.time_ns() @@ -848,11 +849,6 @@ def run(self, *args, grid, stream, **kwargs): if launcher.store_cubin: self.save_gpu_kernel(grid, stream, launcher) - if launcher.config.pre_hook is not None: - launcher.config.pre_hook( - {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs} - ) - if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1": _dump_launch_params(args, kwargs, launcher, self.fn.__name__) @@ -962,7 +958,7 @@ def __init__(self, *args, regex_filter="", with_profiler=False, **kwargs): super().__init__(*args, **kwargs) self.cached = None - def run(self, *args, grid, stream): + def run(self, *args, grid, stream): # type: ignore[override] possible_names = _find_names(self) kernel_name = f"{max(possible_names, key=len)}" if not re.match(self.regex_filter, kernel_name): @@ -1006,74 +1002,6 @@ def hash_configs(configs: List[Config]): return hasher.hexdigest() -def load_cached_autotuning( - best_config, - configs_hash: str, - configs: List[Config], - inductor_meta: Dict[str, Any], -): - if best_config is None: - return None - if best_config.pop("configs_hash", None) != configs_hash: - return None - - # Remove time taken for comparison - best_config.pop("time_taken_ms", None) - - if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( - "found_by_coordesc", False - ): - num_warps = best_config.pop("num_warps") - num_stages = best_config.pop("num_stages") - triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages) - triton_config.found_by_coordesc = True - return triton_config - - matching_configs = [ - cfg - for cfg in configs - if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) - and cfg.num_warps == best_config.get("num_warps") - and cfg.num_stages == best_config.get("num_stages") - ] - if len(matching_configs) != 1: - return None - - return matching_configs[0] - - -def should_use_remote_autotune_cache(inductor_meta): - if inductor_meta.get("autotune_remote_cache") is not None: - return inductor_meta.get("autotune_remote_cache") - if not inductor_meta.get("is_fbcode"): - return False - if torch._utils_internal.is_fb_unit_test(): - return False - if inductor_meta.get("is_hip"): - return False - - try: - from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION - except ModuleNotFoundError: - return False - - return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( - "pytorch/remote_cache:autotune_memcache_version" - ) - - -class LocalAutotuneCache: - def get(self, filename): - if os.path.exists(filename): - with open(filename) as fd: - return json.loads(fd.read()) - return None - - def put(self, filename, data): - with open(filename, "w") as fd: - fd.write(json.dumps(data)) - - def cached_autotune( size_hints: Optional[List[int]], configs: List[Config], @@ -1089,92 +1017,27 @@ def cached_autotune( """ configs = unique_configs(configs) assert len(configs) == 1 or filename - save_cache_hook: Optional[Callable[[Any, Any, Any], Any]] inductor_meta = {} if inductor_meta is None else inductor_meta + disabled = inductor_meta.get("force_disable_caches", False) + # on disk caching logic and/or remote caching - if filename is not None and ( - len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning") + autotune_cache = None + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) ): configs_hash = hash_configs(configs) - local_cache = None - cache_filename = None - remote_cache: Optional[RemoteCacheBackend] = None - remote_cache_key = None - best_config = None - if not inductor_meta.get("force_disable_caches", False): - if inductor_meta.get("autotune_local_cache", True): - local_cache = LocalAutotuneCache() - cache_filename = os.path.splitext(filename)[0] + ".best_config" - if should_use_remote_autotune_cache(inductor_meta): - backend_hash = inductor_meta.get("backend_hash", None) - if backend_hash is not None: - key = backend_hash + configs_hash + "autotune-best-config-v2" - key = hashlib.sha256(key.encode("utf-8")).hexdigest() - - try: - if inductor_meta.get("is_fbcode"): - from torch._inductor.fb.remote_cache import ( - FbRemoteAutotuneCacheBackend, - ) - - remote_cache = FbRemoteAutotuneCacheBackend(key) - else: - from torch._inductor.remote_cache import ( - RedisRemoteCacheBackend, - ) - - remote_cache = RedisRemoteCacheBackend(key) - except Exception: - remote_cache = None - log.warning("Unable to create a remote cache", exc_info=True) - # we already sha256 hash the source contents - remote_cache_key = os.path.basename(filename) - else: - log.debug( - "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" - ) - - best_config = None - if local_cache is not None and cache_filename is not None: - best_config = local_cache.get(cache_filename) - if ( - remote_cache is not None - and remote_cache_key is not None - and best_config is None - ): - best_config = remote_cache.get(remote_cache_key) - - best_config = load_cached_autotuning( - best_config, configs_hash, configs, inductor_meta - ) - if best_config: + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): configs = [best_config] - else: - log.debug("autotune caching is disabled by config.force_disable_caches") - - def save_cache_hook(cfg, time_taken_ns, found_by_coordesc=False): - data = { - **cfg.kwargs, - "num_warps": cfg.num_warps, - "num_stages": cfg.num_stages, - "configs_hash": configs_hash, - "found_by_coordesc": found_by_coordesc, - "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS - } - if local_cache is not None and cache_filename is not None: - local_cache.put(cache_filename, data) - if remote_cache is not None and remote_cache_key is not None: - remote_cache.put(remote_cache_key, data) # type: ignore[arg-type] - - if log.isEnabledFor(logging.DEBUG): - type_str = "coordesc" if found_by_coordesc else "heuristic" - log.debug("Save %s tuning result to %s", type_str, cache_filename) - else: - save_cache_hook = None + if disabled: + log.debug("autotune caching is disabled by config.force_disable_caches") mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) @@ -1201,7 +1064,7 @@ def decorator(fn): "profile_bandwidth_with_do_bench_using_profiling" ], configs=configs, - save_cache_hook=save_cache_hook, + save_cache_hook=autotune_cache and autotune_cache.save, mutated_arg_names=mutated_arg_names, heuristic_type=heuristic_type, size_hints=size_hints, @@ -1213,7 +1076,7 @@ def decorator(fn): triton_meta=triton_meta, inductor_meta=inductor_meta, configs=configs, - save_cache_hook=save_cache_hook, + save_cache_hook=autotune_cache and autotune_cache.save, mutated_arg_names=mutated_arg_names, heuristic_type=heuristic_type, size_hints=size_hints, @@ -1278,10 +1141,10 @@ def _check_max_grid_x(size_hints, x, num_warps): while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints[0]: x *= 2 # Scale up XBLOCK if grid exceeds limits num_blocks = num_blocks // 2 - if x >= max_grid_x: - raise AssertionError( - "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" - ) + if (num_blocks * num_warps * warp_size) > max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) return x, num_blocks @@ -1874,16 +1737,38 @@ def grid_fn(meta): return grid_fn -def grid_combo_kernels(*numels, num_kernels, min_blocks, is_sequential): +def grid_combo_kernels( + *numels, num_kernels, min_blocks, is_sequential, default_meta=None +): """min_blocks is the minimal size of the grid x dimension""" if not is_sequential: # round robin dispatch - kernel_grid_fn = grid(*numels) + numels_agg = list(numels) + for i in range(len(numels_agg)): + if isinstance(numels_agg[i], (list, tuple)): + numels_agg[i] = max(max(numels_agg[i]), 0) # noqa: PLW3301 + kernel_grid_fn = grid(*numels_agg) + + if isinstance(numels[-1], (list, tuple)): + min_blocks_d = max(-min(numels[-1]), 0) * num_kernels + else: + min_blocks_d = None + if min_blocks is None: + assert min_blocks_d is not None + min_blocks = min_blocks_d + else: + assert ( + min_blocks_d is None or min_blocks == min_blocks_d + ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" else: # sequential dispatch seq_numels = list(numels) # x numels are not used here, just a place holder seq_numels[-1] = 1024 + for i in range(len(seq_numels) - 1): + if isinstance(seq_numels[i], (list, tuple)): + seq_numels[i] = max(seq_numels[i]) + kernel_grid_fn = grid(*seq_numels) def get_grid_dim(numel, block): @@ -1894,6 +1779,7 @@ def get_grid_dim(numel, block): return ceildiv(numel, block) def grid_fn(meta): + assert min_blocks is not None, "min_blocks must be a number" cuda_grid = list(kernel_grid_fn(meta)) cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks) return tuple(cuda_grid) @@ -1910,4 +1796,13 @@ def seq_grid_fn(meta): cuda_grid[0] = x_grid return tuple(cuda_grid) - return grid_fn if not is_sequential else seq_grid_fn + def grid_fn_default_meta(meta): + return grid_fn(default_meta) + + def seq_grid_fn_default_meta(meta): + return seq_grid_fn(default_meta) + + if default_meta is None: + return grid_fn if not is_sequential else seq_grid_fn + else: + return grid_fn_default_meta if not is_sequential else seq_grid_fn_default_meta diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 97953f7866d016..3d2676e0b2e000 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -11,7 +11,9 @@ import os import pprint import textwrap +import traceback import typing +from collections import defaultdict from typing import ( Any, Callable, @@ -45,6 +47,7 @@ from .comm_analysis import estimate_nccl_collective_runtime from .dependencies import Dep, MemoryDep, StarDep, WeakDep from .ir import ComputedBuffer, MultiOutput, MultiOutputLayout +from .loop_body import LoopBody from .runtime.runtime_utils import green_text, red_text from .sizevars import SimplifyIndexing from .utils import ( @@ -65,6 +68,7 @@ log = logging.getLogger(__name__) fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") @dataclasses.dataclass @@ -156,27 +160,27 @@ class BaseSchedulerNode: group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]] read_writes: dependencies.ReadWrites unmet_dependencies: OrderedSet[Dep] - - def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. + min_order: int + max_order: int + + def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler + + def _init_from_node(self, node: ir.Operation) -> None: self.node: Optional[ir.Operation] = node - self.set_read_writes(node.get_read_writes()) self.ancestors: OrderedSet[str] = OrderedSet() - # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. - # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node - # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. - # For non-"grouped" nodes (i.e. regular SchedulerNode), - # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. - self.min_order: int - self.max_order: int self.last_usage: OrderedSet[ str ] = OrderedSet() # buffers that won't be used after this kernel self.written = False - self.outputs: List[SchedulerBuffer] = [ SchedulerBuffer( - scheduler=scheduler, + scheduler=self.scheduler, node=output, defining_op=self, ) @@ -240,6 +244,11 @@ def log_details(self) -> None: self.read_writes.writes, ) + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + return + def update_mutated_names(self, renames: Dict[str, str]) -> None: self.set_read_writes(self.read_writes.rename(renames)) @@ -267,9 +276,6 @@ def mark_run(self) -> None: for buf in self.outputs: buf.allocate() - def op_counts(self) -> Counter[str]: - return self.read_writes.op_counts - def used_buffer_names(self) -> OrderedSet[str]: return OrderedSet( dep.name @@ -324,11 +330,13 @@ def get_name(self) -> str: def get_first_name(self) -> str: return self.get_name() + @cache_on_self def get_operation_names(self) -> OrderedSet[str]: - return OrderedSet(node.get_name() for node in self.get_nodes()) + return OrderedSet([node.get_name() for node in self.get_nodes()]) + @cache_on_self def get_buffer_names(self) -> OrderedSet[str]: - return OrderedSet(out.get_name() for out in self.outputs) + return OrderedSet([out.get_name() for out in self.outputs]) def get_nodes(self) -> Sequence[BaseSchedulerNode]: return [self] @@ -497,6 +505,7 @@ def codegen_originating_info( buffer.writelines(out_lines) self.written = True + @cache_on_self def get_read_write_buffers_sizes(self) -> int: """ Counting the number of bytes accessed for a kernel is @@ -602,6 +611,7 @@ def get_buf_bytes(buf: Optional[Union[ir.Buffer, ir.TensorBox]]) -> int: return node_bytes + @cache_on_self def get_estimated_runtime(self) -> float: """ Returns estimated op runtime in nanoseconds (ns) @@ -794,6 +804,11 @@ def should_prune(dep: Dep) -> bool: class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + def debug_str_extra(self) -> str: return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" @@ -806,7 +821,10 @@ def has_side_effects(self) -> bool: class NopKernelSchedulerNode(BaseSchedulerNode): - pass + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) class SchedulerNode(BaseSchedulerNode): @@ -815,34 +833,94 @@ def __init__( scheduler: Scheduler, node: Union[ir.ComputedBuffer, ir.TemplateBuffer], ) -> None: - super().__init__(scheduler, node) + super().__init__(scheduler) + self._init_from_node(node) self._compute_attrs() def _compute_attrs( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) self._sizes, self._body = self.node.simplify_and_reorder( - extra_indexing_constraints=extra_indexing_constraints + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, ) group_fn = self.scheduler.get_backend(self.node.get_device()).group_fn self.group = (self.node.get_device(), group_fn(self._sizes)) + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. + should_normalize = ( + not config.loop_ordering_after_fusion + or self.node.get_device().type != "cuda" + ) + if isinstance(self.node, ir.TemplateBuffer): - self.set_read_writes(self.node.normalized_read_writes()) + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) else: self.set_read_writes( dependencies.extract_read_writes( - self._body, *self._sizes, normalize=True + self._body, *self._sizes, normalize=should_normalize ) ) def recompute_size_and_body( - self, extra_indexing_constraints: Tuple[Dict[Any, Any], List[Any]] + self, + extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: - self._compute_attrs(extra_indexing_constraints=extra_indexing_constraints) + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + def refresh_dependencies(self, normalize: bool) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. + fake_deps = { + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + } + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ).with_read(fake_deps) + ) + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) def debug_str_extra(self) -> str: name = self.get_name() @@ -856,7 +934,7 @@ def debug_str_extra(self) -> str: buf_name = dep.name buf = V.graph.get_buffer(buf_name) lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") - if isinstance(self._body, ir.LoopBody): + if isinstance(self._body, LoopBody): lines.append(f"class {name}_loop_body:") lines.append(textwrap.indent(self._body.debug_str(), " ")) @@ -918,16 +996,15 @@ def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: log.fatal("Error in codegen for %s", self.node) raise + @cache_on_self def pointwise_read_writes(self) -> dependencies.ReadWrites: """ Get the memory dependencies in the non-reduction axis. """ sizes, reduction_sizes = self._sizes - - def fn(index: Sequence[sympy.Symbol]) -> str: - return self._body(index, [sympy.Integer(0) for _ in reduction_sizes]) - - return dependencies.extract_read_writes(fn, sizes) + return dependencies.extract_read_writes( + self._body, sizes, hidden_args=[[sympy.Integer(0)] * len(reduction_sizes)] + ) def can_inplace(self, read_dep: dependencies.Dep) -> bool: if self.is_template(): @@ -945,7 +1022,7 @@ def can_inplace(self, read_dep: dependencies.Dep) -> bool: @cache_on_self def _get_atomic_add_buffers(self) -> OrderedSet[str]: buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() - if isinstance(self._body, ir.LoopBody): + if isinstance(self._body, LoopBody): for node in self._body.get_nodes(): if ( node.op == "call_method" @@ -963,6 +1040,22 @@ def _get_atomic_add_buffers(self) -> OrderedSet[str]: return buffers_store_as_atomic_add +def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: + snodes = group_snode.snodes # type: ignore[attr-defined] + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + def init_group_node( group_snode: BaseSchedulerNode, scheduler: Scheduler, @@ -976,18 +1069,7 @@ def init_group_node( *[x.ancestors for x in snodes if x.ancestors is not None] ) - group_snode.set_read_writes( - dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) - ) - - group_snode.unmet_dependencies = ( - OrderedSet( - dep - for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) - if dep.name not in group_snode.get_buffer_names() - ) - - group_snode.read_writes.writes - ) + refresh_group_node_dependencies(group_snode) group_snode.min_order = min(x.min_order for x in group_snode.snodes) group_snode.max_order = max(x.max_order for x in group_snode.snodes) @@ -1015,8 +1097,45 @@ def fuse( nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) return cls(node1.scheduler, nodes) + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> None: + if self.is_template(): + # We can not really reorder loops for a triton template + return + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and self_sizes != snode._sizes[0]: + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) # type: ignore[arg-type] + + refresh_group_node_dependencies(self) + def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: - # NB: No need to call super().__init__() because we don't need to re-use any of its logic. + super().__init__(scheduler) init_group_node(self, scheduler, snodes) self.users: List[NodeUser] = [] self.group = max(snodes, key=lambda x: int(x.is_reduction())).group @@ -1044,10 +1163,10 @@ def debug_str_extra(self) -> str: for i, node in enumerate(self.snodes) ] node = self.snodes[0].node - assert node is not None - device = node.get_device() - if ir.is_triton(device): - lines.extend(debug_triton_code(self)) + if node is not None: + device = node.get_device() + if ir.is_triton(device): + lines.extend(debug_triton_code(self)) return textwrap.indent("\n".join(lines).rstrip(), " ") @@ -1110,13 +1229,6 @@ def get_device(self) -> torch.device: def has_aliasing_or_mutation(self) -> bool: return any(x.has_aliasing_or_mutation() for x in self.snodes) - @cache_on_self - def op_counts(self) -> Counter[str]: - op_counts: Counter[str] = collections.Counter() - for node in self.snodes: - op_counts.update(node.op_counts()) - return op_counts - # None of these need to be implemented, as a FusedSchedulerNode is just an # abstraction for scheduling purposes def update_mutated_names(self, renames: Dict[str, str]) -> None: @@ -1492,7 +1604,7 @@ def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: return grouped_snode def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: - # NB: No need to call super().__init__() because we don't need to re-use any of its logic. + super().__init__(scheduler) init_group_node(self, scheduler, snodes) def unpack(self) -> List[BaseSchedulerNode]: @@ -1692,6 +1804,7 @@ def _init(self, nodes: List[ir.Operation]) -> None: if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) self.nodes = self.fuse_nodes(self.nodes) + self.merge_loops() self.finalize_multi_template_buffers() if config.reorder_for_compute_comm_overlap: self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) @@ -2122,6 +2235,39 @@ def compute_ancestors(self) -> None: node.min_order = order node.max_order = order + def merge_loops(self) -> None: + for node in self.nodes: + if not config.loop_ordering_after_fusion: + continue + + # Even for CPU, if we are using the halide backend, we still need + # the merge loops steps below + if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( + node.get_device().type != "cuda" and config.cpu_backend != "halide" + ): + continue + for snode in node.get_nodes(): + # merge loops for the scheduler node + if not isinstance(snode, SchedulerNode) or snode.is_template(): + continue + + snode._body = snode._body.merge_loops() + snode._sizes = snode._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + snode.refresh_dependencies(normalize=True) + + # Note that for CPU backend, merging loops will change + # snode.group. It's fine for Triton backend. + # But if we simplify update snode.group like this: + # group_fn = self.get_backend(snode.node.get_device()).group_fn + # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) + # There is still an issue due to different snode in a + # FusedSchedulerNode having different merged loops. + # Skip CPU backend for now. + def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: """ Combine eligible nodes into FusedSchedulerNodes. @@ -2664,6 +2810,76 @@ def decide_fusion_fail_reason( return str(reasons) + def has_shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatibile with node1 if that's more efficient. + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.get_device().type == "cpu" for n in [node1, node2] + ): + return False + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return False + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return False + + # Pick the largest buffer to guide the loop reordering + numel, lhs_dep, rhs_dep = max(candidates, key=lambda x: x[0]) + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. + # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + return lhs_dep.normalize() == rhs_dep.normalize() + + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. %s", + node1.get_name(), + node2.get_name(), + ) + + return self.score_fusion_memory(node1, node2) > 0 + def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: """ Determine if it is possible to combine node1 and node2 into a @@ -2722,6 +2938,17 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: del device2 no_shared_data = self.score_fusion_memory(node1, node2) == 0 + if no_shared_data: + no_shared_data = not self.has_shared_data_after_reordering_loop( + node1, node2 + ) + + loop_ordering_log.debug( + "%s and %s has%s shared data", + node1.get_name(), + node2.get_name(), + " no" if no_shared_data else "", + ) if no_shared_data and ( not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() ): @@ -2780,24 +3007,35 @@ def can_fuse_vertical( be scheduled before the fusion of node1 and node2. """ node1_buf_names = node1.get_buffer_names() - node1_op_names = node1.get_operation_names() - computed_deps: OrderedSet[Dep] = OrderedSet() why = WhyNoFuse(node1, node2) + remaining_deps_by_name: Dict[str, List[Dep]] = defaultdict(list) + + for dep in node2.unmet_dependencies: + name = self.mutation_renames.get(dep.name, dep.name) + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + continue + remaining_deps_by_name[name].append(dep) for cd in node1.read_writes.writes: if not isinstance(cd, MemoryDep): continue - for rd in node2.unmet_dependencies: - if self.fusable_read_and_write(rd, cd): - computed_deps.add(rd) - - for dep in node2.unmet_dependencies: - if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): - computed_deps.add(dep) + remaining = remaining_deps_by_name.get( + self.mutation_renames.get(cd.name, cd.name) + ) + if remaining: + for rd in remaining: + if self.fusable_read_and_write(rd, cd): + remaining.remove(rd) remaining_deps = OrderedSet( - dep.name for dep in node2.unmet_dependencies - computed_deps + [ + dep.name + for dep in itertools.chain.from_iterable( + remaining_deps_by_name.values() + ) + ] ) + if remaining_deps & node1_buf_names: # MemoryDeps didn't match and read different locations of the same buffer. # Examples here include: @@ -2805,6 +3043,8 @@ def can_fuse_vertical( # - MemoryDep("foo", x) != StarDep("foo") why("memory deps did not match") return False + + node1_op_names = node1.get_operation_names() for name in remaining_deps: op_name = self.name_to_buf[name].defining_op.get_name() if node1_op_names & self.name_to_fused_node[op_name].ancestors: @@ -2852,16 +3092,24 @@ def fusable_weak_dep( # if there's indirect indexing, don't match it def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: if isinstance(read, MemoryDep): - if read.mode == write.mode and write.mode is not None: - return True - read_name = read.name - if read_name in self.mutation_renames: - read_name = self.mutation_renames[read_name] + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False + + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. + read = read.normalize() + write = write.normalize() + return ( - read_name == write.name - and not free_symbol_is_type(read.index, SymT.TMP) - and not free_symbol_is_type(write.index, SymT.TMP) - and read.index == write.index + read.index == write.index and len(read.size) >= len(write.size) and read.size[: len(write.size)] == write.size ) @@ -3149,18 +3397,39 @@ def codegen(self) -> None: return self._codegen() def _codegen(self) -> None: - for node in self.nodes: - try: - log.debug( - "Generating code for node %s with estimated runtime %f", - node.get_name(), - node.get_estimated_runtime(), - ) - except Exception as e: - log.debug( - "Generating code for node %s with estimated runtime 0.0", - node.get_name(), + if config.check_stack_no_cycles_TESTING_ONLY: + import torch._dynamo.convert_frame + + stack = traceback.extract_stack() + seen = set() + for frame in reversed(stack): + # This is where maybe_cprofile is + if ( + frame.name == "_compile_inner" + and frame.filename == torch._dynamo.convert_frame.__file__ + ): + break + key = (frame.filename, frame.lineno) + assert key not in seen, ( + f"Duplicate stack frame {frame.filename}:{frame.lineno}; " + "did you add a decorator to one of the functions in this stack " + "trace? If so, try using a context manager instead." ) + seen.add(key) + + for node in self.nodes: + if log.isEnabledFor(logging.DEBUG): + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception as e: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) self.enter_context(node) @@ -3300,7 +3569,7 @@ def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: raise # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking. - small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2 + small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3 if fusion_log.isEnabledFor(logging.DEBUG): if ms1 > ms2 or small_kernel: fusion_log.debug( @@ -3312,8 +3581,8 @@ def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: "cannot fuse (benchmark): fusing causes %sx slowdown", red_text(f"{ms1 / ms2:.3f}"), ) - - return ms2 < ms1 or small_kernel + # ms1 returned by benchmark_fused_nodes discounted clone time + return ms2 - ms2_clone < ms1 or small_kernel def get_buffer_layout(self, buf_name: str) -> ir.Layout: buf = self.name_to_buf[buf_name] diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5b585753bd12e2..7b90690cf50b19 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -203,7 +203,7 @@ def estimate_kernel_num_bytes(self): num_bytes = [] for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): size = V.graph.sizevars.size_hints(inp.get_size()) - numel = functools.reduce(operator.mul, size) + numel = functools.reduce(operator.mul, size, 1) dtype_size = get_dtype_size(inp.get_dtype()) num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) return sum(num_bytes) @@ -1258,6 +1258,11 @@ def no_op(*args, **kwargs): if timings: return no_op + if config.search_autotune_cache and not ( + config.max_autotune or config.max_autotune_gemm + ): + return no_op + precompile_key = create_precompile_key(name, inputs_key, choices) if precompile_func := self.precompile_cache.get(precompile_key): return precompile_func diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3c76bbe432da92..1004817d6e7dbc 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -464,13 +464,23 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: - key = f"__{fn.__name__}_cache" - - @functools.wraps(fn) - def wrapper(self): - if not hasattr(self, key): - setattr(self, key, fn(self)) - return getattr(self, key) + name = fn.__name__ + key = f"__{name}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def {name}_cache_on_self(self): + try: + return self.{key} + except AttributeError: + self.{key} = rv = fn(self) + return rv + """.lstrip(), + ctx, + ) + wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"]) def clear_cache(self): if hasattr(self, key): @@ -1054,7 +1064,9 @@ def is_big_gpu(index) -> bool: def use_max_autotune() -> bool: - return config.max_autotune or config.max_autotune_gemm + return ( + config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache + ) def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: @@ -1251,10 +1263,14 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2, mat2_transposed=False): num_threads=parallel_num_threads(), ) + def is_last_dim_stride1(x): + x.freeze_layout() + return x.get_stride()[-1] == 1 + return ( layout.dtype in layout_dtypes and micro_gemm is not None - and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input + and is_last_dim_stride1(mat1) # TODO(jgong5): support transposed input and isinstance(mat2, ir.StorageBox) and mat2.is_module_buffer() ) @@ -1361,6 +1377,25 @@ def run_and_get_triton_code(fn, *args, **kwargs): return source_codes[0] +def run_and_get_graph_lowering(fn, *args, **kwargs): + from torch._inductor.codecache import CompiledFxGraph + from torch._inductor.graph import GraphLowering + + real_init = CompiledFxGraph.__init__ + graph_lowerings = [] + + def fake_init(*args, **kwargs): + real_init(*args, **kwargs) + graph = args[2] + assert isinstance(graph, GraphLowering) + graph_lowerings.append(graph) + + with mock.patch.object(CompiledFxGraph, "__init__", fake_init): + result = fn(*args, **kwargs) + + return result, graph_lowerings + + @contextlib.contextmanager def override_lowering(aten_op, override_fn): """ diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index b47cd97c1667c0..b00a94c5f2cee3 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -76,7 +76,7 @@ from torch._inductor.codegen.cpp_utils import LocalBufferContext from torch._inductor.debug import DebugContext from torch._inductor.graph import GraphLowering - from torch._inductor.ir import InterpreterShim + from torch._inductor.loop_body import InterpreterShim from torch._subclasses import FakeTensorMode threadlocal = local() diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index bdd1f0fc95b7f9..ddbfbcf19609c8 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -120,9 +120,7 @@ def get_info_str(ms, n_regs, n_spills, shared, prefix=""): f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" ) else: - ms = benchmarker.benchmark_gpu( - lambda: kernel_mod.call(args), rep=40, fast_flush=True - ) + ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) assert ( len(triton_kernel.launchers) == 1 ), "Autotuner should have selected the best config" diff --git a/torch/_lazy/__init__.py b/torch/_lazy/__init__.py index c074abd143723e..8d90efa40e5884 100644 --- a/torch/_lazy/__init__.py +++ b/torch/_lazy/__init__.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import threading import torch._C._lazy from torch.utils._pytree import tree_flatten, tree_unflatten diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index ced50a7892a010..2eb45a78a5edce 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -332,12 +332,16 @@ def backend_impl(*args, **kwargs): fn = self._backend_fns[device_type] module = inspect.getmodule(fn) raise RuntimeError( - f"Tensors returned from custom ops (1) must not " - f"be inputs to the custom op and (2) may not alias " - f"any inputs or other returns. Please clone the " - f"the offending output tensors (e.g. output.clone()) " - f"or refactor your code. " - f"Offending op: {self._name} (with implementation in {module})" + f"{self._name} (with implementation in {module}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." ) storages.add(key) return result diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index b2eeb24521d382..9845a64874803a 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -268,4 +268,4 @@ def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.L elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] return typing.List[type_args[0]] # type: ignore[valid-type] else: - return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] + return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py index 0531869ae2fdf3..5acf175c275222 100644 --- a/torch/_logging/__init__.py +++ b/torch/_logging/__init__.py @@ -9,6 +9,7 @@ from ._internal import ( _init_logs, DEFAULT_LOGGING, + get_structured_logging_overhead, getArtifactLogger, LazyString, set_logs, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 1cf37a2a57b084..f78396545c1f75 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -6,8 +6,12 @@ import logging import os import os.path +import pathlib import re +import sys import tempfile +import time +from collections import defaultdict from dataclasses import dataclass, field from importlib import __import__ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -322,9 +326,6 @@ def set_logs( aot_joint_graph (:class:`bool`): Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` - inductor (:class:`Optional[int]`): - Whether to log information from inductor cudagraphs. Default: ``logging.WARN`` - ddp_graphs (:class:`bool`): Whether to emit graphs generated by DDPOptimizer. Default: ``False`` @@ -750,6 +751,26 @@ def _has_registered_parent(log_qname): return False +def make_module_path_relative(abs_path): + """ + Given an absolute filepath corresponding to a Python module which was + loaded via normal import mechanisms using sys.path, convert it into + a relative path relative to one of the Python search paths. + """ + + abs_path = pathlib.Path(abs_path).resolve() + + for path in sys.path: + try: + rel_path = abs_path.relative_to(path) + except ValueError: + continue + else: + return str(rel_path) + + return str(abs_path) + + # apply custom formats to artifacts when necessary class TorchLogsFormatter(logging.Formatter): def __init__(self, *, trace: bool = False): @@ -810,9 +831,11 @@ def format(self, record): if artifact_name is not None: record.artifactprefix = f" [__{artifact_name}]" + filepath = make_module_path_relative(record.pathname) + prefix = ( f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.process} " - f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:" + f"{filepath}:" f"{record.lineno}]{record.traceid}{record.artifactprefix}" ) if self._is_trace: @@ -1070,6 +1093,42 @@ def __str__(self): return self.func(*self.args, **self.kwargs) +# Logs the time it takes to do structured logging by frame/compile id +# key is always {frame_id}_{frame_compile_id} +structured_logging_overhead: Dict[str, float] = defaultdict(float) + + +# Same principle as add_remote_cache_time_saved, but do it for structured logging +def add_structured_logging_overhead(time_spent: float) -> None: + global structured_logging_overhead + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + # Why not trace_id.attempt, like structured logging? + # We aggregate across all attempts because + # a compilation metric is logged per successful attempt + key = f"{frame_id}_{frame_compile_id}" + # TODO: deal with structured logging that occurs outside of specific compile ids + # It's hard to figure out where we would log that if we want it in compilation metrics + # itself. + if key is not None: + key = str(key) + structured_logging_overhead[key] += time_spent + + +def get_structured_logging_overhead() -> Optional[float]: + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + key = f"{frame_id}_{frame_compile_id}" + if key is not None: + return structured_logging_overhead.get(key) + else: + return None + + def trace_structured( name: str, # NB: metadata expected to be dict so adding more info is forward compatible @@ -1079,6 +1138,7 @@ def trace_structured( payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, suppress_context: bool = False, expect_trace_id: bool = True, # Whether or not we expect to have a current trace id + record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging ): """ metadata is an arbitrary JSON compatible struct, but it's expected to not be @@ -1097,6 +1157,7 @@ def trace_structured( # trace_log never propagates and is ALWAYS DEBUG, so also check that there # are handlers instead of checking the log level if trace_log.handlers: + start_time = time.time_ns() record: Dict[str, object] = {} record[name] = metadata_fn() if not suppress_context: @@ -1135,6 +1196,11 @@ def trace_structured( ) log_trace_structured_event(name, record) + if record_logging_overhead: + # Convert to seconds from nanoseconds, add it to the frame compile total + structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9 + add_structured_logging_overhead(structured_logging_overhead_s) + import torch._guards import torch._utils_internal diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 69af10f4a4a5d5..e2bdedf349a0ac 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -42,6 +42,7 @@ [ "torch._dynamo", "torch.export", + "torch.export.dynamic_shapes", *DYNAMIC, "torch._export.converter", "torch._export.non_strict_utils", @@ -157,6 +158,11 @@ "Detailed Inductor fusion decisions. More detailed than 'schedule'", off_by_default=True, ) +register_artifact( + "loop_ordering", + "Logs related to loop ordering", + off_by_default=True, +) register_artifact( "overlap", "Detailed Inductor compute/comm overlap decisions", diff --git a/torch/_logging/scribe.py b/torch/_logging/scribe.py new file mode 100644 index 00000000000000..18745e468b5e6a --- /dev/null +++ b/torch/_logging/scribe.py @@ -0,0 +1,61 @@ +from typing import Callable, List, Union +from typing_extensions import TypeAlias + + +try: + from fbscribelogger import make_scribe_logger # type: ignore[import-untyped] +except ImportError: + TAtom: TypeAlias = Union[int, float, bool, str] + TField: TypeAlias = Union[TAtom, List[TAtom]] + TLazyField: TypeAlias = Union[TField, Callable[[], TField]] + + def make_scribe_logger(name: str, thrift_src: str) -> Callable[..., None]: + def inner(**kwargs: TLazyField) -> None: + pass + + return inner + + +open_source_signpost = make_scribe_logger( + "TorchOpenSourceSignpost", + """ +struct TorchOpenSourceSignpostLogEntry { + + # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA. + 4: optional string commit_sha; + + # Commit date (not author date) of the commit in commit_sha as timestamp, e.g., 1724208105. Increasing if merge bot is used, though not monotonic; duplicates occur when stack is landed. + 5: optional i64 commit_date; + + # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF. + 6: optional string github_ref; + + # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED. + 7: optional bool github_ref_protected; + + # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT. + 8: optional string github_run_attempt; + + # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID. + 9: optional string github_run_id; + + # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER. + 10: optional string github_run_number_str; + + # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, amz2023.linux.2xlarge). + 11: optional string job_name; + + # The GitHub user who triggered the job. Derived from GITHUB_TRIGGERING_ACTOR. + 12: optional string github_triggering_actor; + 13: optional string name; # Event name + 14: optional string parameters; # Parameters (JSON data) + 16: optional string subsystem; # Subsystem the event is associated with + + # The unit timestamp in second for the Scuba Time Column override + 17: optional i64 time; + + # The weight of the record according to current sampling rate + 18: optional i64 weight; +} +""", # noqa: B950 +) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 80641f53b7a72a..b67f28b75cc413 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -16,10 +16,12 @@ from torch._ops import OpOverload from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND from torch._prims_common import ( + BoolLike, corresponding_complex_dtype, corresponding_real_dtype, elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, IntLike, make_contiguous_strides_for, Number, @@ -3563,6 +3565,45 @@ def meta_binop_inplace(self, other): ], ) def meta_binop_inplace_alpha(self, other, alpha=1): + """ + Some checks for inplace ops. + Checks for promotion rules for some dtypes. + int.add/sub_(float) and bool.add/sub_(others) are rejected. + Promoting in these in-place operations would require reallocating + and copying over elements, hence not allowed. + Checks for alpha param. + """ + + def is_integeric(arg): + if isinstance(arg, TensorLike): + return utils.is_integer_dtype(arg.dtype) + else: + return isinstance(arg, IntLike) + + def is_floatic(arg): + if isinstance(arg, TensorLike): + return utils.is_float_dtype(arg.dtype) + else: + return isinstance(arg, FloatLike) + + def is_booleanic(arg): + if isinstance(arg, TensorLike): + return utils.is_boolean_dtype(arg.dtype) + else: + return isinstance(arg, BoolLike) + + # Do not allow int+float->int in-place + if is_integeric(self) and is_floatic(other): + raise RuntimeError( + "Promotion of int.add/sub_(float) in in-place ops are not possible due to element size change." + ) + + # Do not allow bool+other->bool in-place + if is_booleanic(self) and not is_booleanic(other): + raise RuntimeError( + "Promotion of book.add/sub_(others) in in-place ops are not possible due to element size change." + ) + if isinstance(other, torch.Tensor): check_inplace_broadcast(self.shape, other.shape) return self @@ -5041,6 +5082,106 @@ def meta_scatter_(self, dim, index, src_or_value, reduce=None): return self +@register_meta([aten._scaled_dot_product_flash_attention]) +def meta__scaled_dot_product_flash_attention( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + batch_size = query.size(0) + num_heads = query.size(1) + max_seqlen_batch_q = query.size(2) + head_dim = query.size(3) + max_seqlen_batch_k = key.size(2) + + query_t = query.transpose(1, 2) + attention = torch.empty_like(query_t).transpose(1, 2) + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # Note [Seed and Offset]: device for seed and offset below depends on whether we are + # capturing or not, but at the time of tracing we don't know if we + # are going to use cudagraphs or not, so we return meta tensors here + # it's possible we'll need to have some special handling in inductor for sdpa + + return ( + attention, + logsumexp, + None, + None, + max_seqlen_batch_q, + max_seqlen_batch_k, + torch.empty((), dtype=torch.long, device="meta"), + torch.empty((), dtype=torch.long, device="meta"), + debug_mask, + ) + + +@register_meta([aten._scaled_dot_product_cudnn_attention]) +def meta__scaled_dot_product_cudnn_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + scale: Optional[float] = None, +): + B = query.size(0) + H = query.size(1) + S_Q = query.size(2) + S_KV = key.size(2) + D_QK = query.size(-1) + D_V = value.size(-1) + + res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device) + logsum_exp = torch.empty( + (B, H, S_Q), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset] + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return ( + res, + logsum_exp, + None, + None, + S_Q, + S_KV, + seed, + offset, + None, + ) + + @register_meta( [ aten._scaled_dot_product_flash_attention_backward, @@ -5088,11 +5229,7 @@ def meta__scaled_dot_product_flash_attention_for_cpu( max_seqlen_batch_q = query.size(2) head_dim = query.size(3) - attention = torch.empty( - (batch_size, max_seqlen_batch_q, num_heads, head_dim), - dtype=query.dtype, - device=query.device, - ).transpose(1, 2) + attention = torch.empty_like(query) logsumexp = torch.empty( ( batch_size, @@ -5155,6 +5292,46 @@ def meta__scaled_dot_product_flash_attention_for_cpu_backward( return grad_q, grad_k, grad_v +@register_meta([aten._scaled_dot_product_efficient_attention]) +def meta__scaled_dot_product_efficient_attention( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Optional[Tensor], + compute_log_sumexp: bool, + dropout_p=0.0, + is_causal: bool = False, + scale: Optional[float] = None, +): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 + logsum_exp = torch.empty( + (B, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + res = res.transpose(1, 2) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset + + @register_meta( [ aten._scaled_dot_product_efficient_attention_backward, @@ -5244,6 +5421,71 @@ def meta__scaled_dot_product_cudnn_backward( return grad_q, grad_k, grad_v +@register_meta( + [ + aten._flash_attention_forward, + ] +) +def meta__flash_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + cum_seq_q: Optional[Tensor], + cum_seq_k: Optional[Tensor], + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + scale: Optional[float] = None, + window_size_left: Optional[int] = None, + window_size_right: Optional[int] = None, + seqused_k: Optional[Tensor] = None, + alibi_slopes: Optional[Tensor] = None, +): + # NB: there are two underlying paths: + # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) + # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total + # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total + batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 + max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q + max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k + num_heads = query.size(-2) + head_dim = query.size(-1) + + # Cuda Path + attention = torch.empty_like(query) + logsumexp = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q), + dtype=torch.float, + device=query.device, + ) + + if return_debug_mask: + blocksize_c = 128 if head_dim > 64 else 256 + max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) + if max_seqlen_batch_k <= 128: + max_seqlen_k = 128 + elif max_seqlen_batch_k <= 256: + max_seqlen_k = 256 + debug_mask = torch.empty( + (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), + dtype=query.dtype, + device=query.device, + ) + else: + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + + # See Note [Seed and Offset]: + return ( + attention, + logsumexp, + torch.empty((), dtype=torch.long, device="meta"), + torch.empty((), dtype=torch.long, device="meta"), + debug_mask, + ) + + @register_meta( [ aten._flash_attention_backward, @@ -5275,6 +5517,59 @@ def meta__flash_attention_backward( return grad_query, grad_key, grad_value +@register_meta( + [ + aten._efficient_attention_forward, + ] +) +def meta__efficient_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + bias: Optional[Tensor], + cu_seqlens_q: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + dropout_p: float, + custom_mask_type: int, + compute_log_sumexp: bool = False, + scale: Optional[float] = None, + causal_diagonal: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, + window_size: Optional[int] = None, +): + B = query.size(0) + M = query.size(1) + N = key.size(1) + num_heads = query.size(-2) + K = query.size(-1) + Kv = value.size(-1) + + res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device) + + logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B + actual_max_seqlen_q = M + if cu_seqlens_q is not None: + assert max_seqlen_q is not None + actual_max_seqlen_q = max_seqlen_q + actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N + logsumexp_dim = ( + math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 + ) + logsum_exp = torch.empty( + (logsumexp_batch_dim, num_heads, logsumexp_dim), + dtype=torch.float, + device=query.device, + ) + + # See Note [Seed and Offset]: + seed = torch.empty((), dtype=torch.long, device="meta") + offset = torch.empty((), dtype=torch.long, device="meta") + + return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k + + @register_meta( [ aten._efficient_attention_backward, @@ -5346,12 +5641,6 @@ def meta_scaled_mm( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): - def is_row_major(stride): - return stride[0] > stride[1] and stride[1] == 1 - - def is_col_major(stride): - return stride[0] == 1 and stride[1] > 1 - def is_fp8_type(dtype): return dtype in ( torch.float8_e4m3fn, @@ -5364,68 +5653,77 @@ def is_fp8_type(dtype): self.dim() == 2 and mat2.dim() == 2, lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", ) - torch._check( - is_row_major(self.stride()), - lambda: "self must be row_major", - ) - torch._check( - is_col_major(mat2.stride()), - lambda: "mat2 must be col_major", - ) - torch._check( - self.size(1) % 16 == 0, - lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", - ) - torch._check( - mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, - lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", - ) torch._check( is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) - # determine scaling type and check input dimensions (refer to Blas.cpp op) - torch._check( - scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, - lambda: "Both scale_a and scale_b must be float (fp32) tensors.", - ) - m, k = self.shape - n = mat2.size(1) - if scale_a.numel() == 1 and scale_b.numel() == 1: - # tensorwise scaling - pass - else: - # for non-tensorwise scaling, enforce 2D input tensors + if device_hint(self) == "cuda": + + def is_row_major(stride): + return stride[0] > stride[1] and stride[1] == 1 + + def is_col_major(stride): + return stride[0] == 1 and stride[1] > 1 + + torch._check( + is_row_major(self.stride()), + lambda: "self must be row_major", + ) + torch._check( + is_col_major(mat2.stride()), + lambda: "mat2 must be col_major", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) torch._check( - scale_a.dim() == 2 and scale_b.dim() == 2, - lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}", ) - if ( - scale_a.size(0) == m - and scale_a.size(1) == 1 - and scale_b.size(0) == 1 - and scale_b.size(1) == n - ): - # rowwise scaling - torch._check( - scale_a.is_contiguous() and scale_b.is_contiguous(), - lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", - ) + # determine scaling type and check input dimensions (refer to Blas.cpp op) + torch._check( + scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, + lambda: "Both scale_a and scale_b must be float (fp32) tensors.", + ) + m, k = self.shape + n = mat2.size(1) + if scale_a.numel() == 1 and scale_b.numel() == 1: + # tensorwise scaling + pass else: - # does not match any valid scaling type + # for non-tensorwise scaling, enforce 2D input tensors torch._check( - False, - lambda: ( - "Invalid scaling configuration. " - "For tensorwise scaling, both scales should be scalar. " - f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " - f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " - f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" - ), + scale_a.dim() == 2 and scale_b.dim() == 2, + lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}", ) + if ( + scale_a.size(0) == m + and scale_a.size(1) == 1 + and scale_b.size(0) == 1 + and scale_b.size(1) == n + ): + # rowwise scaling + torch._check( + scale_a.is_contiguous() and scale_b.is_contiguous(), + lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.", + ) + else: + # does not match any valid scaling type + torch._check( + False, + lambda: ( + "Invalid scaling configuration. " + "For tensorwise scaling, both scales should be scalar. " + f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " + f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " + f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" + ), + ) + _out_dtype = out_dtype if out_dtype is not None else self.dtype return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device) @@ -6086,7 +6384,7 @@ def meta_searchsorted( def _check_for_unsupported_isin_dtype(dtype): torch._check( - dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64], + dtype not in (torch.bool, torch.complex128, torch.complex64), lambda: f"Unsupported input type encountered for isin(): {dtype}", ) @@ -6234,28 +6532,6 @@ def meta__jagged_to_padded_dense_forward( return values.new_empty(output_shape) -@register_meta(aten._padded_dense_to_jagged_forward.default) -def meta__padded_dense_to_jagged_forward( - padded: Tensor, - offsets: List[Tensor], - total_L: Optional[int] = None, -): - # only one jagged dim is supported for now - assert len(offsets) == 1 - - if not total_L: - assert isinstance(padded, torch._subclasses.FakeTensor) - shape_env = padded.fake_mode.shape_env - assert shape_env is not None - total_L = shape_env.create_unbacked_symint() - torch.fx.experimental.symbolic_shapes._constrain_range_for_size( - total_L, min=0, max=None - ) - - output_shape = (total_L, *padded.shape[2:]) - return padded.new_empty(output_shape) - - def _create_unary_float_meta_func(func): @register_meta(func) @out_wrapper() diff --git a/torch/_ops.py b/torch/_ops.py index 08d0ffab5e7d70..03ed9ca9260998 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import abc import contextlib import ctypes import importlib @@ -238,7 +239,7 @@ def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] ] -class HigherOrderOperator(OperatorBase): +class HigherOrderOperator(OperatorBase, abc.ABC): # The HigherOrderOperator will appear as torch.ops.higher_order.{name} # # If you're creating a new HigherOrderOperator, please do not change the @@ -410,6 +411,7 @@ def check_overloaded(arg): assert not isinstance(kernel, DispatchKey) return kernel(*args, **kwargs) + @abc.abstractmethod def __call__(self, /, *args, **kwargs): # Dynamo already traces the body of HigherOrderOp beforehand when it # so no need to trace into it. @@ -433,9 +435,6 @@ def wrapper(): def __str__(self): return f"{self.name()}" - # def __repr__(self): - # return f"torch.ops._higher_order_ops.{self._name}" - def name(self): return self._name diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 977d664754982e..26b3f298e36187 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1,16 +1,13 @@ # mypy: allow-untyped-defs -import contextlib -import itertools import operator -import weakref from enum import Enum from functools import partial, reduce -from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union +from typing import Callable, List, Optional, Sequence, Tuple, Type, Union import torch import torch._prims_common as utils import torch.library -from torch import sym_float, Tensor, TypedStorage +from torch import sym_float, Tensor from torch._C import _get_default_device from torch._higher_order_ops.effects import new_token_tensor from torch._library.utils import is_functional_schema diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index f94a0d06ad6b4b..bbbdb8958f9adb 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -142,6 +142,10 @@ def get_device(args, kwargs): devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)} if any(dev == "cuda" for dev in devices): return "cuda" + elif any(dev == "xpu" for dev in devices): + return "xpu" + elif any(dev == "hpu" for dev in devices): + return "hpu" elif any(dev == "cpu" for dev in devices): return "cpu" return None @@ -152,6 +156,9 @@ class RunAndSaveRngState(HigherOrderOperator): def __init__(self): super().__init__("run_and_save_rng_state") + def __call__(self, op, *args, **kwargs): + return super().__call__(op, *args, **kwargs) + run_and_save_rng_state = RunAndSaveRngState() run_and_save_rng_state.py_impl(DispatchKey.Autograd)( @@ -166,9 +173,24 @@ def impl_cuda(op, *args, **kwargs): def impl_cpu(op, *args, **kwargs): return torch.get_rng_state(), op(*args, **kwargs) + @run_and_save_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(op, *args, **kwargs): + if hasattr(torch, "hpu"): + return torch.hpu.get_rng_state(), op(*args, **kwargs) + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_and_save_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(op, *args, **kwargs): + return torch.xpu.get_rng_state(), op(*args, **kwargs) + @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(op, *args, **kwargs): - impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } device = get_device(args, kwargs) assert device in impl_map, f"Backend not supported for {device}" impl = impl_map[device] @@ -198,6 +220,9 @@ class RunWithRngState(HigherOrderOperator): def __init__(self): super().__init__("run_with_rng_state") + def __call__(self, rng_state, op, *args, **kwargs): + return super().__call__(rng_state, op, *args, **kwargs) + run_with_rng_state = RunWithRngState() run_with_rng_state.py_impl(DispatchKey.Autograd)( @@ -220,6 +245,24 @@ def impl_cpu(rng_state, op, *args, **kwargs): torch.set_rng_state(current_state) return out + @run_with_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(rng_state, op, *args, **kwargs): + if hasattr(torch, "hpu"): + current_state = torch.hpu.get_rng_state() + torch.hpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.hpu.set_rng_state(current_state) + return out + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_with_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(rng_state, op, *args, **kwargs): + current_state = torch.xpu.get_rng_state() + torch.xpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.xpu.set_rng_state(current_state) + return out + @run_with_rng_state.py_impl(ProxyTorchDispatchMode) def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): # TODO: you don't need to do this, the dispatch here already disabled @@ -235,7 +278,12 @@ def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): @run_with_rng_state.py_impl(DispatchKey.BackendSelect) def impl_backend_select(rng_state, op, *args, **kwargs): - impl_map = {"cuda": impl_cuda, "cpu": impl_cpu} + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } device = get_device(args, kwargs) assert device in impl_map, f"Backend not supported for {device}" impl = impl_map[device] diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 065ec2eff2384d..61d0ba13b88f15 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -3,10 +3,9 @@ import operator import warnings -import weakref from contextlib import nullcontext from enum import Enum -from functools import cmp_to_key, reduce +from functools import reduce from typing import ( Any, Callable, diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 043322dc9d7448..a89ea7cb9997e6 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -185,11 +185,15 @@ def _maybe_resize_out( return out +def is_cpu_scalar(x: TensorLikeType) -> bool: + return x.dim() == 0 and x.device.type == "cpu" + + def _safe_copy_out( *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False ): # Checks same device - if copy_from.device != copy_to.device: + if not is_cpu_scalar(copy_from) and copy_from.device != copy_to.device: msg = ( f"Attempting to copy from device {copy_from.device} " f"to device {copy_to.device}, but cross-device copies are not allowed!" diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 160a3da6a48909..9b4d10cd5a6ad4 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -290,6 +290,7 @@ "take_along_dim", "tensor_split", "transpose", + "transpose_copy", "unfold", "unfold_copy", "unsqueeze", @@ -804,7 +805,7 @@ def logsumexp( dim = (dim,) if self.numel() == 0: return torch.sum(torch.exp(self), dim, keepdim).log() - maxes = torch.amax(self, dim, keepdim=True) + maxes = torch.amax(torch.real(self), dim, keepdim=True) maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) result = torch.sum(torch.exp(self - maxes), dim, keepdim) @@ -3549,12 +3550,6 @@ def istft( y = y.narrow(dim=1, start=start, length=length) window_envelop = window_envelop.narrow(dim=1, start=start, length=length) - window_envelop_lowest = window_envelop.abs().min().lt(1e-11) - torch._check( - not window_envelop_lowest.item(), - lambda: "window overlap add min less than 1e-11", - ) - y = y / window_envelop if original_ndim == 2: y = y.squeeze(0) @@ -5917,7 +5912,7 @@ def triu_indices( @register_decomposition(aten.bucketize) @out_wrapper(exact_dtype=True) def bucketize( - a: TensorLikeType, + a: TensorOrNumberLikeType, boundaries: TensorLikeType, *, out_int32: bool = False, @@ -5928,6 +5923,7 @@ def bucketize( lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", ) + a = a if isinstance(a, torch.Tensor) else torch.tensor(a) out_dtype = torch.int32 if out_int32 else torch.int64 n_boundaries = boundaries.shape[-1] if n_boundaries == 0: @@ -6340,6 +6336,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) t_copy = _make_copy_from_view(aten.t) +transpose_copy = _make_copy_from_view(aten.transpose) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) view_copy = _make_copy_from_view(aten.view) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 77a635ea9ef69b..6585f57e3d643b 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from functools import partial -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch._prims as prims @@ -15,7 +15,6 @@ DimsType, ELEMENTWISE_TYPE_PROMOTION_KIND, IntLike, - NumberType, TensorLikeType, ) from torch._prims_common.wrappers import ( diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py index 197fcad97165e1..13132188a1930e 100644 --- a/torch/_strobelight/compile_time_profiler.py +++ b/torch/_strobelight/compile_time_profiler.py @@ -1,7 +1,9 @@ # mypy: disallow-untyped-defs +import json import logging import os +import subprocess from datetime import datetime from socket import gethostname from typing import Any, Optional @@ -22,6 +24,60 @@ logger.propagate = False +def get_fburl(url: str) -> str: + short_url = url + # Attempt to shorten the URL + try: + result = subprocess.run( + ["fburl", url], capture_output=True, stdin=subprocess.DEVNULL + ) + if result.returncode == 0: + short_url = result.stdout.decode("utf-8") + except Exception as e: + logger.warning("URL shortening failed: %s, using long URL", repr(e)) + return short_url + + +def get_strobelight_url(identifier: str) -> str: + scuba_json = { + "aggregateList": [], + "aggregation_field": "async_stack_complete", + "b_constraints": [[]], + "c_constraints": [[]], + "cols": ["namespace_id", "namespace_process_id"], + "compare": "none", + "constraints": [ + [{"column": "sample_tags", "op": "all", "value": [f'["{identifier}"]']}] + ], + "derivedCols": [], + "end": "now", + "enumCols": [], + "filterMode": "DEFAULT", + "hideEmptyColumns": "false", + "ignoreGroupByInComparison": "false", + "is_timeseries": "false", + "mappedCols": [], + "metric": "count", + "modifiers": [], + "order": "weight", + "order_desc": "true", + "param_dimensions": [ + {"dim": "py_async_stack", "op": "edge", "param": "0", "anchor": "0"} + ], + "purposes": [], + "return_remainder": "false", + "samplingRatio": "1", + "should_pivot": "false", + "start": "-30 days", + "timezone": "America/Los_Angeles", + "top": 10000, + } + scuba_url_prefix = "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experimental/on_demand&drillstate=" + scuba_url_suff = "&view=GraphProfilerView&&normalized=1726332703&pool=uber" + long_url = scuba_url_prefix + json.dumps(scuba_json) + scuba_url_suff + return get_fburl(long_url) + + class StrobelightCompileTimeProfiler: success_profile_count: int = 0 failed_profile_count: int = 0 @@ -89,27 +145,8 @@ def _cls_init(cls) -> None: logger.info("Unique sample tag for this run is: %s", cls.identifier) logger.info( - "You can use the following link to access the strobelight profile at the end of the run: %s", - ( - "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime" - "ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no" - "w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%" - "22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n" - "amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%" - "2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22" - "%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C" - "%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang" - "eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2" - "C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack" - "_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param" - "_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22" - "edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde" - "r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[[" - "%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%" - f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b" - "_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi" - "ew=GraphProfilerView&&normalized=1712358002&pool=uber" - ), + "URL to access the strobelight profile at the end of the run: %s", + get_strobelight_url(cls.identifier), ) @classmethod diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 8c6a8521d56719..b80b200a3c52ba 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -373,6 +373,7 @@ def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): return repeats.new_empty(output_size) +@register_op_impl(torch.ops.aten.item.default) @register_op_impl(torch.ops.aten._local_scalar_dense.default) def local_scalar_dense(fake_mode, func, arg): if (r := arg.item_memo) is not None: @@ -436,6 +437,38 @@ def nonzero(fake_mode, func, arg): return arg.new_empty((nnz, arg.dim()), dtype=torch.int64) +@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default) +def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None): + # only one jagged dim is supported for now + assert len(offsets) == 1 + + if not total_L: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + total_L = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(padded.numel()): + maxval = int(padded.numel()) + + _constrain_range_for_size(total_L, min=0, max=maxval) + + output_shape = (total_L, *padded.shape[2:]) + return padded.new_empty(output_shape) + + @register_op_impl(torch.ops.aten.masked_select.default) def masked_select(fake_mode, func, self, mask): if ( @@ -455,10 +488,23 @@ def masked_select(fake_mode, func, self, mask): _constrain_range_for_size, has_free_symbols, ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + # If num elements is expressed symbolically, calculate + # the concrete value based on upper bounds. Otherwise, + # we can set max val directly. if not has_free_symbols(self.numel()): - if self.numel() > 2: - maxval = int(self.numel()) + num_elements = int(self.numel()) + else: + prod_node = math.prod(self.shape).node + prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range) + if isinstance(prod_range.upper, IntInfinity): + num_elements = sys.maxsize - 1 + else: + num_elements = prod_range.upper + if num_elements > 2: + maxval = num_elements _constrain_range_for_size(nnz, max=maxval) @@ -704,318 +750,6 @@ def convert(t, mem_fmt): ) -@register_op_impl(aten._scaled_dot_product_flash_attention.default) -def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - return_debug_mask = kwargs["return_debug_mask"] - # unused: value, dropout_p, is_causal, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - batch_size = query.size(0) - num_heads = query.size(1) - max_seqlen_batch_q = query.size(2) - head_dim = query.size(3) - max_seqlen_batch_k = key.size(2) - - query_t = query.transpose(1, 2) - # empty_like already returns a fake tensor so we don't need to convert it - attention = torch.empty_like(query_t).transpose(1, 2) - logsumexp = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q), - dtype=torch.float, - device="meta", - ), - device=query.device, - ) - - if return_debug_mask: - blocksize_c = 128 if head_dim > 64 else 256 - max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) - if max_seqlen_batch_k <= 128: - max_seqlen_k = 128 - elif max_seqlen_batch_k <= 256: - max_seqlen_k = 256 - debug_mask = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), - dtype=query.dtype, - device="meta", - ), - device=query.device, - ) - else: - debug_mask = convert_tensor( - torch.empty(0, dtype=query.dtype, device="meta"), - query.device, - ) - - # Note [Seed and Offset]: device for seed and offset below depends on whether we are - # capturing or not, but at the time of tracing we don't know if we - # are going to use cudagraphs or not, so we return meta tensors here - # it's possible we'll need to have some special handling in inductor for sdpa - - return ( - attention, - logsumexp, - None, - None, - max_seqlen_batch_q, - max_seqlen_batch_k, - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - debug_mask, - ) - - -@register_op_impl(aten._scaled_dot_product_cudnn_attention.default) -def meta__scaled_dot_product_cudnn(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: attn_bias, dropout_p, is_causal, return_debug_mask, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - B = query.size(0) - H = query.size(1) - S_Q = query.size(2) - S_KV = key.size(2) - D_QK = query.size(-1) - D_V = value.size(-1) - - res = convert_tensor( - torch.empty(B, H, S_Q, D_V, dtype=query.dtype, device="meta"), - query.device, - ) - - logsum_exp = convert_tensor( - torch.empty( - (B, H, S_Q), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - return ( - res, - logsum_exp, - None, - None, - S_Q, - S_KV, - seed, - offset, - None, - ) - - -@register_op_impl(aten._scaled_dot_product_efficient_attention.default) -def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: attn_bias, dropout_p, is_causal, scale - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - B = query.size(0) - M = query.size(1) - N = key.size(1) - num_heads = query.size(-2) - K = query.size(-1) - Kv = value.size(-1) - - res = convert_tensor( - torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), - query.device, - ) - - logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0 - logsum_exp = convert_tensor( - torch.empty( - (B, num_heads, logsumexp_dim), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - res = res.transpose(1, 2) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - - return res, logsum_exp, seed, offset - - -@register_op_impl(aten._flash_attention_forward.default) -def meta__flash_attention_forward(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - cum_seq_q = kwargs["cum_seq_q"] - cum_seq_k = kwargs["cum_seq_k"] - max_q = kwargs["max_q"] - max_k = kwargs["max_k"] - return_debug_mask = kwargs["return_debug_mask"] - # unused: value, dropout_p, is_causal, scale - # unused: seqused_k, alibi_slopes, window_size_left, window_size_right - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - # NB: there are two underlying paths: - # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim) - # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total - # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total - batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1 - max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q - max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k - num_heads = query.size(-2) - head_dim = query.size(-1) - - # Cuda Path - # note: empty_like already returns a fake tensor, we don't need to wrap it - attention = torch.empty_like(query) - logsumexp = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q), - dtype=torch.float, - device="meta", - ), - device=query.device, - ) - - if return_debug_mask: - blocksize_c = 128 if head_dim > 64 else 256 - max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c) - if max_seqlen_batch_k <= 128: - max_seqlen_k = 128 - elif max_seqlen_batch_k <= 256: - max_seqlen_k = 256 - debug_mask = convert_tensor( - torch.empty( - (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k), - dtype=query.dtype, - device="meta", - ), - query.device, - ) - else: - debug_mask = convert_tensor( - torch.empty(0, dtype=query.dtype, device="meta"), - query.device, - ) - - # See Note [Seed and Offset]: - return ( - attention, - logsumexp, - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device), - debug_mask, - ) - - -@register_op_impl(aten._efficient_attention_forward.default) -def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs): - _, kwargs = normalize_function( - func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True - ) - - query = kwargs["query"] - key = kwargs["key"] - value = kwargs["value"] - cu_seqlens_q = kwargs["cu_seqlens_q"] - max_seqlen_q = kwargs["max_seqlen_q"] - max_seqlen_k = kwargs["max_seqlen_k"] - compute_log_sumexp = kwargs["compute_log_sumexp"] - # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, seqlen_k - - def convert_tensor(t, device): - return FakeTensor(fake_mode, t, device) - - B = query.size(0) - M = query.size(1) - N = key.size(1) - num_heads = query.size(-2) - K = query.size(-1) - Kv = value.size(-1) - - res = convert_tensor( - torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"), - query.device, - ) - - logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B - actual_max_seqlen_q = M - if cu_seqlens_q is not None: - assert max_seqlen_q is not None - actual_max_seqlen_q = max_seqlen_q - actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N - logsumexp_dim = ( - math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0 - ) - logsum_exp = convert_tensor( - torch.empty( - (logsumexp_batch_dim, num_heads, logsumexp_dim), - dtype=torch.float, - device="meta", - ), - query.device, - ) - - # See Note [Seed and Offset]: - seed = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - offset = convert_tensor( - torch.empty((), dtype=torch.long, device="meta"), query.device - ) - - return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k - - @register_op_impl(torch.ops.aten._pack_padded_sequence.default) def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first): if ( diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 0450d4d19b5fb0..17b31a8e19ba3f 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -14,6 +14,7 @@ from collections import defaultdict from dataclasses import dataclass from typing import ( + Any, Callable, cast, Dict, @@ -893,28 +894,14 @@ def get_nested_int( ) return self.nested_int_memo * coeff - # We must handle tolist in a special way for FakeTensors here in the case - # where tolist is called from torch dispatch for tensor subclasses. - # Ordinarily, if a program calls .tolist compiling still works because there is - # special handling in dynamo, but for tensor subclasses if .tolist is called - # inside torch dispatch, the .tolist call may be directly on a FakeTensor. - # This would result in an error since wrapper subclasses don't have storage. - # To avoid this, we handle the FakeTensor case by (1) specializing on the size - # of the tensor to create the output Python list, and (2) creating unbacked - # symints for each element of the list. - def tolist(self) -> List[SymInt]: - assert self.dim() == 1, "NYI for higher dims" - shape_env = self.fake_mode.shape_env - assert shape_env is not None - out = [] - # Specialize on the length of the list - for _ in range(self.shape[0]): - s = shape_env.create_unbacked_symint() - # max value? - torch._check_is_size(s) - torch._check(s >= 2) - out.append(s) - return out + # Similar to FunctionalTensor.tolist + def tolist(self) -> Any: + if self.dim() == 0: + return self.item() + elif self.dim() == 1: + return [elem.item() for elem in self] + else: + return [elem.tolist() for elem in self] _MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"] @@ -1715,6 +1702,24 @@ def _dispatch_impl( ) -> Optional[FakeTensor]: flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING + # We must throw NotImplemented in case of unrecognized types to handle subclasses. + # Throwing the exception will pass the control to the next __torch_dispatch__. + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. + has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] has_symbolic_sizes = any( i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors @@ -1755,21 +1760,6 @@ def _dispatch_impl( out = out.clone() return converter.from_real_tensor(self, out, make_constant=True) - # See [subclass inputs] below - # NB: If you're seeing a mysterious infinite loop involving fake - # tensor, it might be related to this line. Though I'm not sure - # how you'll know to read this comment, as this line won't show up - # in the stack trace. - has_unrecognized_types = _check_for_subclass(flat_args) - if has_unrecognized_types: - unrecognized_types = [ - type(x) for x in flat_args if _check_for_subclass_arg(x) - ] - not_implemented_log.debug( - "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types - ) - return NotImplemented - # if we are in the dispatch mode, we will enter this function even if the inputs # are not FakeTensors. For now, throw if any non-Fake Tensor inputs # and just support constructors. @@ -1887,6 +1877,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: for a in flat_args ) ): + log.debug("propagate_real_tensors %s", func) real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) real_out = func(*real_args, **real_kwargs) @@ -1898,7 +1889,7 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: # However, if there's a bug in the condition above, this condition # will also trigger. log.debug( - "propagate_real_tensors skipped %s(%s, %s) %s", + "SKIPPED propagate_real_tensors %s(%s, %s) %s", func, flat_arg_fake_tensors, flat_args, @@ -1908,17 +1899,40 @@ def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: def maybe_propagate_real_tensors(fake_out: T) -> T: import sympy + log.debug("maybe_propagate_real_tensors %s", func) + def go(t: object, real_t: Tensor) -> None: if isinstance(t, FakeTensor): # NB: unconditionally overwrite + log.debug( + "maybe_propagate_real_tensors %s -> %s", id(t), id(real_t) + ) t.real_tensor = real_t + for s, real_s in zip(t.size(), real_t.size()): + go(s, real_s) # type: ignore[arg-type] + for s, real_s in zip(t.stride(), real_t.stride()): + go(s, real_s) # type: ignore[arg-type] + go(t.storage_offset(), real_t.storage_offset()) # type: ignore[arg-type] elif isinstance(t, py_sym_types) and free_unbacked_symbols(t): if isinstance(t.node.expr, sympy.Symbol): assert self.shape_env is not None self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) if real_out is not nil: - tree_map_(go, fake_out, real_out) + if ( + not isinstance(fake_out, Tensor) + and not isinstance(real_out, Tensor) + and type(fake_out) != type(real_out) + ): + # This can happen when decompositions have different return types, + # e.g. namedtuple vs. tuple vs. list. + tree_map_( + go, + tuple(pytree.tree_flatten(fake_out)), + tuple(pytree.tree_flatten(real_out)), + ) + else: + tree_map_(go, fake_out, real_out) # If a data-dependent op is used in a decomposition, we # may need to get the unbacked settings "early" @@ -1950,13 +1964,15 @@ def go(t: object, real_t: Tensor) -> None: ) ): with self: - return decomposition_table[func](*args, **kwargs) + return maybe_propagate_real_tensors( + decomposition_table[func](*args, **kwargs) + ) with self: # Decomposes CompositeImplicitAutograd ops r = func.decompose(*args, **kwargs) if r is not NotImplemented: - return r + return maybe_propagate_real_tensors(r) # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index d0a2c13b395df3..28fc7a4028917f 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -66,6 +66,12 @@ def is_sdpa_error(func, idx, e): and "Devices" in repr(e) ): return True + if ( + func is aten._scaled_dot_product_cudnn_attention.default + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True return False diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index dd7cba26b1cb2e..cd5bfa655ea0b5 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs import contextlib import warnings +import weakref from abc import ABC, abstractmethod -from typing import Any, Callable, ContextManager, Dict, Optional, Tuple, Union +from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union import torch import torch.utils._pytree as pytree @@ -110,7 +111,10 @@ class FunctionalTensor(torch.Tensor): torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type] ] - def __new__(cls, elem): + # Used by auto_functionalize to determine base of tensors during inference mode. + _inference_mode_base: Optional["FunctionalTensor"] = None + + def __new__(cls, elem, mode): assert torch._is_functional_tensor(elem) # In general, we'd like our functional tensor subclass to only be in charge of functionalization, @@ -141,9 +145,9 @@ def __new__(cls, elem): cls, elem.shape, # sizes elem.stride() if not is_sparse_any(elem) else None, # strides - elem.storage_offset() - if not is_sparse_any(elem) - else None, # storage_offset + ( + elem.storage_offset() if not is_sparse_any(elem) else None + ), # storage_offset None, # memory_format elem.dtype, # dtype elem.layout, # layout @@ -157,6 +161,21 @@ def __new__(cls, elem): ) torch._C._set_throw_on_mutable_data_ptr(out) out.elem = elem + + if ( + torch.is_inference_mode_enabled() + and torch._inductor.config.enable_auto_functionalized_v2 + ): + if out.is_base_tensor(): + out._inference_mode_base = None + # This assumes that the FunctionalTensor.elem does not change its storage after this point. + # Otherwise this would be invalid. + mode._storage_to_base[out.elem.untyped_storage()] = out + else: + out._inference_mode_base = mode._storage_to_base[ + out.elem.untyped_storage() + ] + assert out._inference_mode_base is not None return out def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -202,12 +221,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" ) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"FunctionalTensor({repr(self.elem)})" @staticmethod def to_functional(x): # We will do the wrapping for the user. + assert not torch._is_functional_tensor(x) # The only autograd metadata we care about on the FunctionalTensor is: # - requires_grad (so autograd runs) @@ -225,7 +245,7 @@ def to_functional(x): with functional_mode: torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined] - out = FunctionalTensor(x_functional) + out = FunctionalTensor(x_functional, functional_mode) torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined] return out @@ -233,6 +253,9 @@ def from_functional(self): torch._sync(self) return torch._from_functional_tensor(self.elem) + def is_base_tensor(self) -> bool: + return torch._is_functional_tensor_base(self.elem) + def replace_(self, output) -> None: torch._functionalize_replace(self.elem, output) @@ -279,7 +302,7 @@ def cuda(self, device=None, *args, **kwargs): long = _conversion_method_template(dtype=torch.int64) # TODO(sparse-team): fixes #133174 but can we do without the relay? - def to_dense(self): + def to_dense(self): # type: ignore[override] return self.elem.to_dense() @property @@ -315,6 +338,10 @@ def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=Fals # discovery. This flag distinguishes between the two stages. self._allow_token_discovery = _allow_token_discovery + self._storage_to_base: weakref.WeakKeyDictionary[ + torch.storage.UntypedStorage, Optional[FunctionalTensor] + ] = weakref.WeakKeyDictionary() + # No-op if FunctionalTensorMode is already in use def __enter__(self): def _get_prev_mode(): @@ -342,12 +369,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + if self.export: + # We need to make sure that we don't decompose to() as usual in export mode, + # because it can get optimized away. Instead we always replace it with _to_copy(). + if func == torch.ops.aten.to.dtype_layout: + kwargs.pop("copy", None) + return self.__torch_dispatch__( + torch.ops.aten._to_copy.default, types, args, kwargs + ) + if func == torch.ops.aten.to.dtype: + schema = tuple(arg.name for arg in func._schema.arguments) + for arg, name in zip(args[1:], schema[1:]): + kwargs[name] = arg + kwargs.pop("copy", None) + return self.__torch_dispatch__( + torch.ops.aten._to_copy.default, types, args[:1], kwargs + ) + unrecognized_types = [ t for t in types if not issubclass(t, torch._subclasses.FakeTensor) and t not in [torch.Tensor, FunctionalTensor] ] + if unrecognized_types: not_implemented_log.debug( "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types @@ -399,16 +444,13 @@ def _can_decompose(func): if r is not NotImplemented: return r - def assert_is_functional(x): - assert torch._is_functional_tensor(x) - def wrap(x): # Only wrap our outputs in subclasses if the inner functionalization call # also wrapped outputs into FunctionalTensorWrappers. # When can this happen? e.g. `torch.div(2, 2)` assert not isinstance(x, FunctionalTensor) if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): - return FunctionalTensor(x) + return FunctionalTensor(x, self) return x def unwrap(x): @@ -417,6 +459,7 @@ def unwrap(x): from torch._higher_order_ops.auto_functionalize import ( can_auto_functionalize, do_auto_functionalize, + do_auto_functionalize_v2, ) if can_auto_functionalize( @@ -427,7 +470,12 @@ def unwrap(x): # it doesn't matter what mode we use here because # the implementation of do_auto_functionalize doesn't # interact with FunctionalTensorMode at all - return do_auto_functionalize(func, args, kwargs) + import torch._inductor.config as inductor_config + + if self.export or not inductor_config.enable_auto_functionalized_v2: + return do_auto_functionalize(func, args, kwargs) + else: + return do_auto_functionalize_v2(func, args, kwargs) from torch._higher_order_ops.effects import handle_effects, has_effects @@ -594,7 +642,7 @@ def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: @abstractmethod def unwrap_tensors( self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]] - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + ) -> Any: pass @abstractmethod @@ -637,8 +685,8 @@ def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]: ) def unwrap_tensors( - self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]] - ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]] + ) -> Any: return torch.utils._pytree.tree_map_only( FunctionalTensor, FunctionalTensor.from_functional, args ) @@ -731,9 +779,11 @@ def unwrap_tensors( def functionalize(self, inner_f: Callable) -> Callable: return torch.func.functionalize( inner_f, - remove="mutations_and_views" - if self.interpreter.functionalize_add_back_views() - else "mutations", + remove=( + "mutations_and_views" + if self.interpreter.functionalize_add_back_views() + else "mutations" + ), ) def redispatch_to_next(self) -> ContextManager: diff --git a/torch/_tensor.py b/torch/_tensor.py index 98563aebae9aa7..e22c6a92d92d0f 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,8 +209,19 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) state = torch._utils._get_obj_state(self) - if type(self) is Tensor and not state: + # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has + # some state that cannot be pickled + if ( + # TODO: remove hasattr, it's a hack to support versions of torch that + # don't have _subclasses + hasattr(torch, "_subclasses") + and type(self) is torch._subclasses.fake_tensor.FakeTensor + and materialize_fake_tensors + ) or (type(self) is Tensor and not state): # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -251,6 +262,12 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() + + skip_data = torch.serialization._serialization_tls.skip_data + materialize_fake_tensors = ( + torch.serialization._serialization_tls.materialize_fake_tensors + ) + # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -268,6 +285,10 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -280,6 +301,10 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. + if skip_data: + warnings.warn( + "Serializing tensors on the meta device under skip_data context manager is a no-op" + ) arg_meta = ( self.dtype, tuple(self.size()), @@ -288,6 +313,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: + if skip_data: + raise RuntimeError( + "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" + ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -369,6 +398,10 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: + if skip_data: + raise RuntimeError( + "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" + ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -383,14 +416,30 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance( - self, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) + or ( + not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and self.data_ptr() == 0 ) - or self.data_ptr() == 0 + ) + ): + arg_wrapper_subclass = ( + type(self), + self.dtype, + tuple(self.size()), + self.stride(), + self.storage_offset(), + self.layout, + self.device, + self.requires_grad, + ) + return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) + elif ( + type(self) is not torch.Tensor + and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ + and ( + isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and not (skip_data and materialize_fake_tensors) ) ): arg_wrapper_subclass = ( @@ -418,6 +467,16 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] + + # TODO: remove hasattr, it's a hack to support versions of torch that + # don't have _subclasses + if ( + hasattr(torch, "_subclasses") + and isinstance(self, torch._subclasses.fake_tensor.FakeTensor) + and skip_data + ): + storage._fake_device = self.device + args = ( storage, self.storage_offset(), @@ -1294,7 +1353,7 @@ def align_to(self, *names): [name for name in names if not is_ellipsis(name)], ellipsis_idx ) - def unflatten(self, dim, sizes): + def unflatten(self, dim, sizes): # type: ignore[override] r""" unflatten(dim, sizes) -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index bba6dc2e307329..8d83872f863802 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2276,7 +2276,7 @@ def merge_dicts(*dicts): r""" cat(tensors, dim=0, *, out=None) -> Tensor -Concatenates the given sequence of :attr:`seq` tensors in the given dimension. +Concatenates the given sequence of tensors in :attr:`tensors` in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size ``(0,)``. @@ -10665,7 +10665,7 @@ def merge_dicts(*dicts): Example:: >>> torch.nansum(torch.tensor([1., float("nan")])) - 1.0 + tensor(1.) >>> a = torch.tensor([[1, 2], [3., float("nan")]]) >>> torch.nansum(a) tensor(6.) diff --git a/torch/_utils.py b/torch/_utils.py index 938392fa971590..f0d38daa811490 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,7 +3,6 @@ import functools import logging import sys -import threading import traceback import warnings from collections import defaultdict @@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] -_thread_local_state = threading.local() - - def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = getattr(_thread_local_state, "map_location", None) + map_location = torch.serialization._serialization_tls.map_location if map_location is None: return device else: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index cb6230618bf04b..f254217452061f 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -132,6 +132,10 @@ def log_trace_structured_event(*args, **kwargs) -> None: pass +def log_cache_bypass(*args, **kwargs) -> None: + pass + + def log_torchscript_usage(api: str, **kwargs): _ = api return @@ -155,6 +159,115 @@ def capture_pre_autograd_graph_using_training_ir() -> bool: return False +class JustKnobsConfig: + """Represents a lazily loaded config + + This is designed to be used to specify a value in a config. + + i.e. foo.bar = JustknobsConfig(name="//foo:bar", env_name="FORCE_FOO_BAR") + + Call .get() in order to access the value + i.e. if foo.bar.get(): + + Note that the value is fetched once, and then not allowed to change. This + means less suprises, at the downside that you may have to restart a job + to pick up an update. + + It can also be set explicitly via set - i.e. + foo.bar = JustknobsConfig(name="//foo:bar") + foo.bar.set(True) + + Note that this does allow for no JK name (so that you can use this to replace old configurations). + """ + + def __init__( + self, *, name: Optional[str] = None, env_name=None, default: bool = True + ): + self.name = name + self.env_name = env_name + self.default = default + self.value: Optional[bool] = None + self.executed_value = None + + def set(self, value: bool): + self.value = value + + def get(self): + if self.executed_value is None: + self.executed_value = justknobs_feature( + self.name, + config_value=self.value, + env_name=self.env_name, + default=self.default, + ) + return self.executed_value + + def __str__(self): + v = bool(self) + return f"JustknobsConfig(name={self.name}, env_name={self.env_name}, default={self.default} - evals_to={v})" + + def __bool__(self): + return self.get() + + +def justknobs_feature( + name: Optional[str], config_value=None, env_name=None, default: bool = True +): + """Returns whether or not a specific justknob feature is enabled. + + This is a slightly higher level API then justknobs_check, designed to make it "easy" to do the right thing. + The primary thing it does, is allow configuration to override JK by default, while retaining some features to force this + the other way during sevs. + + The preference order (i.e. who wins first) in OSS (and FB) is + - Config if specified + - Environment Variable if specified + - JK (FB), or default (OSS) + + + Quickstart + Have a config variable + Make a JK which is set to your "enabled" value (generally true). + Use this feature to check it (if you set the JK to be false, change the default). + If you have an env variable, also use the function to check it. + + Arguments: + name - This should correspond 1:1 to a JK name internally to FB. + env_name - If this is set, we'll try and read the value from environment variables + config_value - If this is set to anything other than None, we'll use this value by + default. Note that within FB, there is some functionality to force override these + configs + default - This is the value to return in OSS. This avoids having to write weird double + negatives within justknobs and the config code, if you just want to have the + killswitch work by having feature return True to turn off features + + Requirements: + WARNING - Don't use this at import time - Simply pass in the existing config. + If you want to use this at config time, use JustKnobsConfig + """ + if config_value is not None: + return config_value + if env_name is not None and ((env := os.getenv(env_name)) is not None): + env = env.upper() + if env in ("1", "TRUE"): + return True + if env in ("0", "FALSE"): + return False + log.error( + "Difficulty parsing env variable %s=%s for feature %s - Assuming env variable means true and returning True", + env_name, + env, + name, + ) + # We could return default here, but that was confusing to log. + return True + if name is None: + return True + if not default: + return not justknobs_check(name) + return justknobs_check(name) + + def justknobs_check(name: str) -> bool: """ This function can be used to killswitch functionality in FB prod, diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index c7fff5b98c7b45..063b57d859e75d 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -261,7 +261,8 @@ def load(self): self.append(cls.__new__(cls, *args)) else: raise UnpicklingError( - f"Trying to instantiate unsupported class {cls}" + "Can only create new object for nn.Parameter or classes allowlisted " + f"via `add_safe_globals` but got {cls}" ) elif key[0] == REDUCE[0]: args = self.stack.pop() @@ -291,7 +292,8 @@ def load(self): inst.__dict__.update(state) else: raise UnpicklingError( - f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}" + "Can only build Tensor, Parameter, OrderedDict or types allowlisted " + f"via `add_safe_globals`, but got {type(inst)}" ) # Stack manipulation elif key[0] == APPEND[0]: diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index f5a50bbe2b3e2c..6aba6bbad42efc 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -322,6 +322,15 @@ def __init__( raise RuntimeError( "Current CUDA Device does not support bfloat16. Please switch dtype to float16." ) + elif self.device == "mps": + supported_dtype = [torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "MPS Autocast only supports dtype of torch.bfloat16 currently." + ) + warnings.warn(error_message) + enabled = False elif self.device == "xla": supported_dtype = [torch.float16, torch.bfloat16] if self.fast_dtype not in supported_dtype: diff --git a/torch/ao/nn/intrinsic/modules/fused.py b/torch/ao/nn/intrinsic/modules/fused.py index 010b9a701a3353..c7ae2ce3319d3c 100644 --- a/torch/ao/nn/intrinsic/modules/fused.py +++ b/torch/ao/nn/intrinsic/modules/fused.py @@ -228,7 +228,7 @@ def __init__(self, conv, add): super().__init__(conv) self.add = add - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.add(self[0](x1), x2) @@ -241,5 +241,5 @@ def __init__(self, conv, add, relu): self.add = add self.relu = relu - def forward(self, x1, x2): + def forward(self, x1, x2): # type: ignore[override] return self.relu(self.add(self[0](x1), x2)) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 3c700b30f24720..6ceee89757ca75 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import math -from typing import TypeVar +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -33,12 +33,9 @@ } -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - - class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule): _version = 2 - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -365,7 +362,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" assert mod.qconfig, "Input float module must have a valid qconfig" qconfig = mod.qconfig - conv, bn = mod[0], mod[1] + conv, bn = mod[0], mod[1] # type: ignore[index] qat_convbn = cls( conv.in_channels, conv.out_channels, @@ -434,7 +431,7 @@ def to_float(self): return conv -class ConvBn1d(_ConvBnNd, nn.Conv1d): +class ConvBn1d(_ConvBnNd, nn.Conv1d): # type: ignore[misc] r""" A ConvBn1d module is a module fused from Conv1d and BatchNorm1d, attached with FakeQuantize modules for weight, @@ -451,10 +448,10 @@ class ConvBn1d(_ConvBnNd, nn.Conv1d): weight_fake_quant: fake quant module for weight """ - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE: None = None - _FLOAT_MODULE = nni.ConvBn1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBn1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -520,12 +517,12 @@ class ConvBnReLU1d(ConvBn1d): """ # base class defines _FLOAT_MODULE as "ConvBn1d" - _FLOAT_MODULE = nni.ConvBnReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE = nn.BatchNorm1d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvBnReLU1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm1d]] = nn.BatchNorm1d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU1d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d def __init__( self, @@ -585,10 +582,10 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU1d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv1d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU1d]] = nni.ConvReLU1d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -631,7 +628,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn2d(_ConvBnNd, nn.Conv2d): +class ConvBn2d(_ConvBnNd, nn.Conv2d): # type: ignore[misc] r""" A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, attached with FakeQuantize modules for weight, @@ -648,10 +645,10 @@ class ConvBn2d(_ConvBnNd, nn.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn2d - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn2d]] = nni.ConvBn2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -717,12 +714,12 @@ class ConvBnReLU2d(ConvBn2d): """ # base class defines _FLOAT_MODULE as "ConvBn2d" - _FLOAT_MODULE = nni.ConvBnReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE = nn.BatchNorm2d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU2d]] = nni.ConvBnReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm2d]] = nn.BatchNorm2d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU # type: ignore[assignment,misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU2d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU2d]]] = nni.ConvReLU2d def __init__( self, @@ -782,10 +779,10 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU2d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv2d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nn.Module]] = nni.ConvReLU2d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, @@ -828,7 +825,7 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): ) -class ConvBn3d(_ConvBnNd, nn.Conv3d): +class ConvBn3d(_ConvBnNd, nn.Conv3d): # type: ignore[misc] r""" A ConvBn3d module is a module fused from Conv3d and BatchNorm3d, attached with FakeQuantize modules for weight, @@ -845,10 +842,10 @@ class ConvBn3d(_ConvBnNd, nn.Conv3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBn3d - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nni.ConvBn3d]] = nni.ConvBn3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -913,12 +910,12 @@ class ConvBnReLU3d(ConvBn3d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvBnReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE = nn.BatchNorm3d - _FLOAT_RELU_MODULE = nn.ReLU # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nni.ConvBnReLU3d]] = nni.ConvBnReLU3d # type: ignore[assignment, misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Type[nn.BatchNorm3d]] = nn.BatchNorm3d + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.ReLU]]] = nn.ReLU # type: ignore[assignment, misc] # module class after fusing bn into conv - _FUSED_FLOAT_MODULE = nni.ConvReLU3d + _FUSED_FLOAT_MODULE: ClassVar[Optional[Type[nni.ConvReLU3d]]] = nni.ConvReLU3d def __init__( self, @@ -980,10 +977,10 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nni.ConvReLU3d # type: ignore[assignment] - _FLOAT_CONV_MODULE = nn.Conv3d - _FLOAT_BN_MODULE: None = None - _FLOAT_RELU_MODULE = nn.ReLU + _FLOAT_MODULE: ClassVar[Type[nni.ConvReLU3d]] = nni.ConvReLU3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _FLOAT_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _FLOAT_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nn.ReLU def __init__( self, diff --git a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py index 66299394091b7c..0d1b7e01f4479f 100644 --- a/torch/ao/nn/intrinsic/quantized/modules/conv_add.py +++ b/torch/ao/nn/intrinsic/quantized/modules/conv_add.py @@ -49,7 +49,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: @@ -117,7 +117,7 @@ def __init__( dtype=dtype, ) - def forward(self, input, extra_input): + def forward(self, input, extra_input): # type: ignore[override] # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: diff --git a/torch/ao/nn/qat/modules/conv.py b/torch/ao/nn/qat/modules/conv.py index 3be5a7cec8162e..49b559136bbc6e 100644 --- a/torch/ao/nn/qat/modules/conv.py +++ b/torch/ao/nn/qat/modules/conv.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Tuple, TypeVar, Union +from typing import ClassVar, Tuple, Type, Union import torch import torch.nn as nn @@ -10,11 +10,9 @@ __all__ = ["Conv1d", "Conv2d", "Conv3d"] -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) - class _ConvNd(nn.modules.conv._ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -136,8 +134,8 @@ class Conv1d(_ConvNd, nn.Conv1d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv1d - _FLOAT_CONV_MODULE = nn.Conv1d + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d def __init__( self, @@ -197,8 +195,8 @@ class Conv2d(_ConvNd, nn.Conv2d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv2d - _FLOAT_CONV_MODULE = nn.Conv2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d def __init__( self, @@ -261,8 +259,8 @@ class Conv3d(_ConvNd, nn.Conv3d): Attributes: weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = nn.Conv3d - _FLOAT_CONV_MODULE = nn.Conv3d + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d # type: ignore[assignment,misc] + _FLOAT_CONV_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d def __init__( self, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index fa86dcdd93e17d..f379e487f65de8 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -2,6 +2,7 @@ r"""Dynamically quantized convolution modules.""" import warnings +from typing import ClassVar, Optional, Type import torch import torch.ao.nn.quantized as nnq @@ -47,9 +48,9 @@ class Conv1d(nnq.Conv1d): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -132,9 +133,9 @@ class Conv2d(nnq.Conv2d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -216,9 +217,9 @@ class Conv3d(nnq.Conv3d): >>> output = m(input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment] - _NNI_CONV_RELU_MODULE = None # type: ignore[assignment] + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -308,7 +309,7 @@ class ConvTranspose1d(nnq.ConvTranspose1d): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -390,7 +391,7 @@ class ConvTranspose2d(nnq.ConvTranspose2d): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -472,7 +473,7 @@ class ConvTranspose3d(nnq.ConvTranspose3d): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/nn/quantized/modules/conv.py b/torch/ao/nn/quantized/modules/conv.py index 48ca18028acf94..0ef2f3af1dcd44 100644 --- a/torch/ao/nn/quantized/modules/conv.py +++ b/torch/ao/nn/quantized/modules/conv.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs r"""Quantized convolution modules.""" -from typing import List, Optional, TypeVar +from typing import ClassVar, List, Optional, Type import torch import torch.ao.nn.intrinsic as nni @@ -386,11 +386,11 @@ class Conv1d(_ConvNd): """ - _FLOAT_MODULE = nn.Conv1d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d - _NNI_CONV_RELU_MODULE = nni.ConvReLU1d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv1d]] = nn.Conv1d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn1d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU1d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -518,11 +518,11 @@ class Conv2d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv2d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d - _NNI_CONV_RELU_MODULE = nni.ConvReLU2d - _NNI_CONV_ADD_MODULE = nni.ConvAdd2d - _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d + _FLOAT_MODULE: ClassVar[Type[nn.Conv2d]] = nn.Conv2d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn2d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU2d + _NNI_CONV_ADD_MODULE: ClassVar[Type[nni.ConvAdd2d]] = nni.ConvAdd2d + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Type[nni.ConvAddReLU2d]] = nni.ConvAddReLU2d def __init__( self, @@ -647,11 +647,11 @@ class Conv3d(_ConvNd): >>> output = m(q_input) """ - _FLOAT_MODULE = nn.Conv3d - _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d - _NNI_CONV_RELU_MODULE = nni.ConvReLU3d - _NNI_CONV_ADD_MODULE: None = None - _NNI_CONV_ADD_RELU_MODULE: None = None + _FLOAT_MODULE: ClassVar[Type[nn.Conv3d]] = nn.Conv3d + _NNIQAT_CONV_BN_MODULE: ClassVar[Optional[Type[nn.Module]]] = nniqat.ConvBn3d + _NNI_CONV_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = nni.ConvReLU3d + _NNI_CONV_ADD_MODULE: ClassVar[Optional[Type[nn.Module]]] = None + _NNI_CONV_ADD_RELU_MODULE: ClassVar[Optional[Type[nn.Module]]] = None def __init__( self, @@ -740,11 +740,10 @@ def from_float(cls, mod, use_precomputed_fake_quant=False): # === Transposed Convolutions === -MOD = TypeVar("MOD", bound=nn.modules.conv._ConvNd) class _ConvTransposeNd(_ConvNd): - _FLOAT_MODULE = MOD + _FLOAT_MODULE: ClassVar[Type[nn.modules.conv._ConvNd]] def __init__( self, @@ -914,7 +913,7 @@ class ConvTranspose1d(_ConvTransposeNd): torch.Size([1, 16, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose1d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose1d]] = nn.ConvTranspose1d def __init__( self, @@ -1037,7 +1036,7 @@ class ConvTranspose2d(_ConvTransposeNd): torch.Size([1, 16, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose2d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose2d]] = nn.ConvTranspose2d def __init__( self, @@ -1162,7 +1161,7 @@ class ConvTranspose3d(_ConvTransposeNd): torch.Size([1, 16, 12, 12, 12]) """ - _FLOAT_MODULE = nn.ConvTranspose3d + _FLOAT_MODULE: ClassVar[Type[nn.ConvTranspose3d]] = nn.ConvTranspose3d def __init__( self, diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 9a80b4af6f4794..9e4cf94303c9f1 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -198,7 +198,7 @@ def __init__(self): self.stats["float"] = [] self.stats["quantized"] = [] - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index 3cea146b8031fd..da79b03ce88089 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -261,7 +261,7 @@ def __init__(self, *args, **kwargs): self.comparisons = [] # precalculated comparisons function - def forward(self, x, x_ref): + def forward(self, x, x_ref): # type: ignore[override] # fmt: off """ """ # blank docblock to make autodoc happy diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py index 875af8c875b0a4..56784dd1780308 100644 --- a/torch/ao/ns/fx/mappings.py +++ b/torch/ao/ns/fx/mappings.py @@ -410,7 +410,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: # Add function swaps from default lowering path # - for source, ( + for source, ( # type:ignore[assignment] target1, target2, ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items(): @@ -422,7 +422,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: _lower_to_native_backend.QBIN_RELU_OP_MAPPING, quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # @@ -432,7 +432,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]: for source_to_target in ( quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, ): - for source, target in source_to_target.items(): + for source, target in source_to_target.items(): # type:ignore[assignment] new_connections.append((source, target)) # add the new connections from backend_config diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index b81ee658e55318..75d86737d832d5 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -71,7 +71,7 @@ def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, **defaults # add data with default config here [self.add_data(name, data, **self.defaults) for name, data in data_list] - def prepare(self): + def prepare(self, model, config): raise NotImplementedError("this function is undefined for this class") def _extract_weight(self, data): @@ -266,7 +266,7 @@ def __getstate__(self): "_container": self._container.state_dict(), } - def __repr__(self): + def __repr__(self): # type:ignore[override] format_string = self.__class__.__name__ + " (" for name, sparse_args in self.data_groups.items(): format_string += "\n" @@ -299,7 +299,7 @@ def squash_mask(self, *args, leave_parametrized=True, names=None, **kwargs): self._container, name, leave_parametrized=leave_parametrized ) - def step(self): + def step(self): # type:ignore[override] if not self.enable_mask_update: return with torch.no_grad(): diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index 34444fa79df0fe..ec867205f64ab1 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -156,7 +156,7 @@ def __get_data_level_mask(self, data, sparsity_level, sparse_block_shape): ] # squeeze only the first 2 dimension return mask - def update_mask( + def update_mask( # type: ignore[override] self, name, data, sparsity_level, sparse_block_shape, zeros_per_block, **kwargs ): values_per_block = reduce(operator.mul, sparse_block_shape) diff --git a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py index 54177272ce4f2a..3da27ba38df55b 100644 --- a/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/FPGM_pruner.py @@ -75,7 +75,9 @@ def _compute_distance(self, t): return distance - def update_mask(self, module, tensor_name, sparsity_level, **kwargs): + def update_mask( # type: ignore[override] + self, module, tensor_name, sparsity_level, **kwargs + ): tensor_weight = getattr(module, tensor_name) mask = getattr(module.parametrizations, tensor_name)[0].mask diff --git a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py index 47c567ec6d93eb..a4d42ea803289c 100644 --- a/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py +++ b/torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py @@ -33,7 +33,9 @@ def __init__(self, nearliness: int = 1): defaults = {"nearliness": nearliness} super().__init__(defaults=defaults) - def update_mask(self, module, tensor_name, nearliness, **kwargs): + def update_mask( # type:ignore[override] + self, module, tensor_name, nearliness, **kwargs + ): mask = getattr(module.parametrizations, tensor_name)[0].mask mask.data = torch.zeros_like(mask) if nearliness <= 0: diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index 3c2eef0dd15c79..a25b7ffbca61cb 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -208,7 +208,7 @@ def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None) mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() return mask - def update_mask( + def update_mask( # type: ignore[call-override, override] self, module, tensor_name, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index adf70f79e20443..503755d383bb93 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -11,6 +11,7 @@ from .observer import * # noqa: F403 from .pt2e._numeric_debugger import ( # noqa: F401 compare_results, + CUSTOM_KEY, extract_results_from_loggers, generate_numeric_debug_handle, NUMERIC_DEBUG_HANDLE_KEY, @@ -162,6 +163,7 @@ "swap_module", "weight_observer_range_neg_127_to_127", "generate_numeric_debug_handle", + "CUSTOM_KEY", "NUMERIC_DEBUG_HANDLE_KEY", "prepare_for_propagation_comparison", "extract_results_from_loggers", @@ -214,5 +216,5 @@ def __init__( def forward(self, x: Tensor) -> Tensor: return x - def calculate_qparams(self): + def calculate_qparams(self): # type:ignore[override] return self.derive_qparams_fn(self.obs_or_fqs) diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 4be37467de1818..03d259a0ad26ca 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -61,7 +61,7 @@ def __init__(self): self.float_sum = None self.quant_sum = None - def forward(self, x, y): + def forward(self, x, y): # type: ignore[override] """Compute the average of quantized and floating-point data from modules. The inputs x,y are output data from the quantized and floating-point modules. diff --git a/torch/ao/quantization/experimental/linear.py b/torch/ao/quantization/experimental/linear.py index 34b0ca8e3921db..0093550472e0ca 100644 --- a/torch/ao/quantization/experimental/linear.py +++ b/torch/ao/quantization/experimental/linear.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import numpy as np +import numpy.typing as npt import torch from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule @@ -148,7 +149,7 @@ def forward(self, activation: torch.Tensor) -> torch.FloatTensor: weight_rows = self.weight_transposed.size()[0] weight_cols = self.weight_transposed.size()[1] - decomposed_weight: np.ndarray = np.empty( + decomposed_weight: npt.NDArray = np.empty( shape=(weight_rows, weight_cols), dtype=object ) for row in range(weight_rows): diff --git a/torch/ao/quantization/experimental/observer.py b/torch/ao/quantization/experimental/observer.py index 631ec11b1669d1..1cf27c97054c1e 100644 --- a/torch/ao/quantization/experimental/observer.py +++ b/torch/ao/quantization/experimental/observer.py @@ -34,7 +34,7 @@ def __init__(self, b, k, dtype=torch.quint8) -> None: # min_val and max_val are optional args to override # the min_val and max_val observed by forward - def calculate_qparams(self, signed): + def calculate_qparams(self, signed): # type:ignore[override] return self._calculate_qparams(signed, self.min_val, self.max_val) r""" Calculates nonuniform quantization parameters according to APoT paper: diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 4502d180cfc655..83487374845330 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch -from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY +from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY from torch.ao.quantization.backend_config import ( BackendConfig, get_native_backend_config, @@ -229,10 +229,15 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): node.replace_all_uses_with(dequantized_node) # propagate numeric debug handle from observer/fake_quant node to dequantize node - if NUMERIC_DEBUG_HANDLE_KEY in node.meta: - dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ - NUMERIC_DEBUG_HANDLE_KEY - ] + if ( + CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] + ): + if CUSTOM_KEY not in dequantized_node.meta: + dequantized_node.meta[CUSTOM_KEY] = {} + dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] graph.erase_node(node) elif is_dynamic: # uint8/int8/fp16 dynamic quantization diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index e26f03027116e2..9b4e26743ab677 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -253,6 +253,7 @@ def __init__( torch.int32, torch.float8_e5m2, torch.float8_e4m3fn, + torch.uint16, ) assert ( @@ -368,6 +369,8 @@ def _calculate_qparams( ) else: zero_point = zero_point.new_full(zero_point.size(), 128) + elif self.dtype in [torch.uint16]: + zero_point = zero_point.new_full(zero_point.size(), 2**15) elif self.qscheme == torch.per_channel_affine_float_qparams: scale = (max_val - min_val) / float(quant_max - quant_min) scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py index 1bcc442ca441a9..fedcf470a18a1d 100644 --- a/torch/ao/quantization/pt2e/_numeric_debugger.py +++ b/torch/ao/quantization/pt2e/_numeric_debugger.py @@ -9,7 +9,8 @@ from torch.nn import functional as F -NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle" +NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle" +CUSTOM_KEY = "custom" log = logging.getLogger(__name__) @@ -20,8 +21,14 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None: """ unique_id = 0 for node in graph_module.graph.nodes: - if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta: - node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id + if node.op in ["output", "placeholder"]: + continue + + if CUSTOM_KEY not in node.meta: + node.meta[CUSTOM_KEY] = {} + + if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]: + node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id unique_id += 1 @@ -98,9 +105,12 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: - if NUMERIC_DEBUG_HANDLE_KEY not in n.meta: + if ( + CUSTOM_KEY not in n.meta + or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY] + ): continue - numeric_debug_handle = n.meta[NUMERIC_DEBUG_HANDLE_KEY] + numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] _insert_logger(model, n, numeric_debug_handle) model.recompile() diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index d758870017ca59..290e3d771c3230 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -4,6 +4,7 @@ import torch from torch._subclasses import FakeTensor from torch.ao.quantization import ( + CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY, ObserverOrFakeQuantize, QConfigMapping, @@ -97,20 +98,14 @@ def _unwrap_shared_qspec( return qspec -def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): - return ( - hasattr(qspec_a, "dtype") - and hasattr(qspec_b, "dtype") - and qspec_a.dtype == qspec_b.dtype - ) - - -def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase): +def _has_same_attr( + qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str +): return ( - hasattr(qspec_a, "is_dynamic") - and hasattr(qspec_b, "is_dynamic") - and qspec_a.is_dynamic == qspec_b.is_dynamic - ) + hasattr(qspec_a, attr_name) + and hasattr(qspec_b, attr_name) + and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name) + ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name)) def _get_edge_or_node_to_qspec( @@ -147,10 +142,18 @@ def _union_input_edge_with( qspec = edge_or_node_to_qspec[edge_or_node] root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map) # TODO: add assertions for types of root qspecs - if ( - root_qspec is not None - and _has_same_dtype(root_qspec, input_edge_root_qspec) - and _has_same_is_dynamic(root_qspec, input_edge_root_qspec) + if root_qspec is not None and all( + _has_same_attr(root_qspec, input_edge_root_qspec, attr) + for attr in [ + "dtype", + "is_dynamic", + "quant_min", + "quant_max", + "qscheme", + "ch_axis", + "scale", + "zero_point", + ] ): # the input arg to the node should reuse the existing output observer for arg # since dtype is the same (we may want to extend this to be a more strict check @@ -459,11 +462,14 @@ def _maybe_insert_output_observer_for_node( if ( isinstance(node, Node) and isinstance(new_output, Node) - and NUMERIC_DEBUG_HANDLE_KEY in node.meta + and CUSTOM_KEY in node.meta + and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] ): - new_output.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ - NUMERIC_DEBUG_HANDLE_KEY - ] + if CUSTOM_KEY not in new_output.meta: + new_output.meta[CUSTOM_KEY] = {} + new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[ + CUSTOM_KEY + ][NUMERIC_DEBUG_HANDLE_KEY] return new_output return None diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 763c181be0caa6..a867acbeb1cb86 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -281,7 +281,7 @@ def get_default_qconfig(backend="x86", version=0): weight=default_weight_observer, ) elif backend == "onednn": - if not torch.cpu._is_cpu_support_vnni(): + if not torch.cpu._is_vnni_supported(): warnings.warn( "Default qconfig of oneDNN backend with reduce_range of false may have accuracy issues " "on CPU without Vector Neural Network Instruction support." diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 574af30a7159b9..6042fd2ee5adb9 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -225,7 +225,7 @@ def _map_module_function_to_aten_operator_type(): ), ) for map_item in map_list: - module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[call-overload] + module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1])) # type: ignore[arg-type, call-overload] return module_function_to_aten_operator diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index d243932a1ef82d..293e37ed456fb4 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -473,6 +473,10 @@ def calculate_qmin_qmax( quant_min, quant_max = 0, 255 elif dtype in [torch.qint32, torch.int32]: quant_min, quant_max = -1 * (2**31), (2**31) - 1 + elif dtype in [torch.uint16]: + quant_min, quant_max = 0, 2**16 - 1 + elif dtype in [torch.int16]: + quant_min, quant_max = -(2**15), 2**15 - 1 else: quant_min, quant_max = 0, 15 return quant_min, quant_max diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index aad31e223b174b..4792ced32b2703 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -108,6 +108,13 @@ def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle: See :ref:`backward-hooks-execution` for more information on how when this hook is executed, and how its execution is ordered relative to other hooks. + .. note:: + In the rare case where the hook is registered while the Node has already + begun execution, there is no longer any guarantee on :attr:`grad_outputs` + content (it might be as usual or empty depending on other factors). The + hook can still optionally return a new gradient to be used in place of + :attr:`grad_inputs` independent of :attr:`grad_outputs`. + Example:: >>> import torch @@ -179,7 +186,8 @@ def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node: if isinstance(t, GradientEdge): return t.node if t.requires_grad and t.grad_fn is None: - node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] + with torch.enable_grad(): + node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] else: node = t.grad_fn assert node is not None diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 0cd81e4c208599..3957c044cbbca4 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -501,6 +501,9 @@ def __init__( self.is_legacy: bool = is_legacy self.flops: Optional[int] = flops self.is_user_annotation: Optional[bool] = is_user_annotation + self.self_cpu_percent = -1 + self.total_cpu_percent = -1 + self.total_device_percent = -1 def append_kernel(self, name, device, duration): assert self.device_type == DeviceType.CPU @@ -1017,26 +1020,33 @@ def trim_path(path, src_column_width): name = evt.key if max_name_column_width is not None and len(name) >= max_name_column_width - 3: name = name[: (max_name_column_width - 3)] + "..." + evt.self_cpu_percent = _format_time_share( + evt.self_cpu_time_total, sum_self_cpu_time_total + ) + evt.total_cpu_percent = ( + _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) + if not evt.is_async + else 0 + ) row_values = [ name, # Self CPU total %, 0 for async events. - _format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total), + evt.self_cpu_percent, evt.self_cpu_time_total_str, # Self CPU total # CPU total %, 0 for async events. - _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total) - if not evt.is_async - else 0, + evt.total_cpu_percent, evt.cpu_time_total_str, # CPU total evt.cpu_time_str, # CPU time avg ] if has_device_time: + evt.total_device_percent = _format_time_share( + evt.self_device_time_total, sum_self_device_time_total + ) row_values.extend( [ evt.self_device_time_total_str, # device time total % - _format_time_share( - evt.self_device_time_total, sum_self_device_time_total - ), + evt.total_device_percent, evt.device_time_total_str, evt.device_time_str, # device time avg ] diff --git a/torch/backends/cpu/__init__.py b/torch/backends/cpu/__init__.py index cc8a174626bf17..82dc52cd4904c1 100644 --- a/torch/backends/cpu/__init__.py +++ b/torch/backends/cpu/__init__.py @@ -16,5 +16,6 @@ def get_cpu_capability() -> str: - "NO AVX" - "AVX2" - "AVX512" + - "SVE256" """ return torch._C._get_cpu_capability() diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 406b20dc779550..7da8e911b83b22 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -123,6 +123,7 @@ def fn(a): def substitute_in_graph( original_fn: _F, *, + can_constant_fold_through: bool = False, skip_signature_check: bool = False, ) -> Callable[[_F], _F]: """ @@ -142,6 +143,10 @@ def substitute_in_graph( Args: original_fn (callable): The original function, usually a C function, to register a polyfill handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. skip_signature_check (bool, optional): Whether to skip the signature check between the original function and the polyfill handler. Defaults to ``False``. @@ -173,6 +178,7 @@ def substitute_in_graph( return torch._dynamo.substitute_in_graph( original_fn, + can_constant_fold_through=can_constant_fold_through, skip_signature_check=skip_signature_check, ) diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 978405c0410c69..8443e0447aa25a 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -29,25 +29,30 @@ _device_t = Union[_device, str, int, None] -def _is_cpu_support_avx2() -> bool: +def _is_avx2_supported() -> bool: r"""Returns a bool indicating if CPU supports AVX2.""" - return torch._C._cpu._is_cpu_support_avx2() + return torch._C._cpu._is_avx2_supported() -def _is_cpu_support_avx512() -> bool: +def _is_avx512_supported() -> bool: r"""Returns a bool indicating if CPU supports AVX512.""" - return torch._C._cpu._is_cpu_support_avx512() + return torch._C._cpu._is_avx512_supported() -def _is_cpu_support_vnni() -> bool: +def _is_avx512_bf16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512_BF16.""" + return torch._C._cpu._is_avx512_bf16_supported() + + +def _is_vnni_supported() -> bool: r"""Returns a bool indicating if CPU supports VNNI.""" # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later. - return torch._C._cpu._is_cpu_support_avx512_vnni() + return torch._C._cpu._is_avx512_vnni_supported() -def _is_cpu_support_amx_tile() -> bool: +def _is_amx_tile_supported() -> bool: r"""Returns a bool indicating if CPU supports AMX_TILE.""" - return torch._C._cpu._is_cpu_support_amx_tile() + return torch._C._cpu._is_amx_tile_supported() def _init_amx() -> bool: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 35dbb0854a8545..19433d62985fd6 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -69,6 +69,7 @@ #include #include #include +#include #include #include #include @@ -170,7 +171,7 @@ static PyObject* THPModule_initExtension( PyObject* _unused, PyObject* shm_manager_path) { HANDLE_TH_ERRORS -#if !defined(FBCODE_CAFFE2) +#if !defined(FBCODE_CAFFE2) && !defined(__aarch64__) if (torch::get_cpp_stacktraces_enabled()) { c10::SetStackTraceFetcher([]() -> std::string { auto tb = torch::CapturedTraceback::gather(false, false, true); @@ -1687,6 +1688,7 @@ PyObject* initModule() { torch::python::init_bindings(module); torch::lazy::initLazyBindings(module); torch::inductor::initAOTIRunnerBindings(module); + torch::inductor::initAOTIPackageBindings(module); #ifdef USE_ITT torch::profiler::initIttBindings(module); #endif diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 77520b6f1cdb1f..5e029fb0b6bd88 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -334,37 +335,38 @@ static PyObject* THPStorage_pynew( allocator = reinterpret_cast(allocator_opt.value()); } else if (device_opt.has_value()) { at::Device device = device_opt.value(); - if (device.type() == at::kCPU) { - allocator = c10::GetDefaultCPUAllocator(); + torch::utils::maybe_initialize_device(device); + + switch (device.type()) { + case at::kCPU: + allocator = c10::GetDefaultCPUAllocator(); + break; #ifdef USE_CUDA - } else if (device.type() == at::kCUDA) { - at::globalContext().lazyInitCUDA(); - allocator = c10::cuda::CUDACachingAllocator::get(); + case at::kCUDA: + allocator = c10::cuda::CUDACachingAllocator::get(); + break; #endif #ifdef USE_MPS - } else if (device.type() == at::kMPS) { - allocator = at::mps::GetMPSAllocator(); + case at::kMPS: + allocator = at::mps::GetMPSAllocator(); + break; #endif - // NOLINTBEGIN(bugprone-branch-clone) - } else if (device.type() == at::DeviceType::XPU) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::HPU) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::Meta) { - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::PrivateUse1) { - at::globalContext().lazyInitPrivateUse1(); - allocator = c10::GetAllocator(device.type()); - } else if (device.type() == at::DeviceType::MAIA) { - allocator = c10::GetAllocator(device.type()); - } else { - // NOLINTEND(bugprone-branch-clone) - TORCH_CHECK( - false, - THPStorageStr, - "(): Storage device not recognized: ", - device.type()); + case at::DeviceType::XPU: + case at::DeviceType::HPU: + case at::DeviceType::Meta: + case at::DeviceType::PrivateUse1: + case at::DeviceType::MAIA: + allocator = c10::GetAllocator(device.type()); + break; + default: + // NOLINTEND(bugprone-branch-clone) + TORCH_CHECK( + false, + THPStorageStr, + "(): Storage device not recognized: ", + device.type()); } + device_guard.reset_device(device); } else { allocator = c10::GetDefaultCPUAllocator(); diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 4bd16066143fe4..a9d3f64f914550 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -49,9 +49,6 @@ static PyObject* THPStorage_nbytes(PyObject* self, PyObject* noargs) { static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - // PyLong_FromVoidPtr should not need to mutate the pointer in order - // to extract a new long object from it. - auto self_ = THPStorage_Unpack(self); // See Note [Invalid Python Storages] auto invalid = self_.data() == nullptr && @@ -59,7 +56,7 @@ static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) { TORCH_CHECK( !invalid, "Attempted to access the data pointer on an invalid python storage.") - return PyLong_FromVoidPtr(self_.mutable_data()); + return torch::autograd::utils::wrap(self_.mutable_data()); END_HANDLE_TH_ERRORS } @@ -190,6 +187,16 @@ static PyObject* THPStorage_fill_(PyObject* self, PyObject* number_arg) { END_HANDLE_TH_ERRORS } +template +static void decodeWrapper( + void* data, + const uint8_t* src, + bool do_byte_swap, + size_t count) { + torch::utils::THP_decodeBuffer( + static_cast(data), src, do_byte_swap, count); +} + static PyObject* THPStorage_fromBuffer( PyObject* _unused, PyObject* args, @@ -308,70 +315,30 @@ static PyObject* THPStorage_fromBuffer( c10::GetDefaultCPUAllocator(), /*resizable=*/true); + static const std::unordered_map< + at::ScalarType, + std::function> + decode_map = { + {at::kBool, decodeWrapper}, + {at::kShort, decodeWrapper}, + {at::kInt, decodeWrapper}, + {at::kLong, decodeWrapper}, + {at::kHalf, decodeWrapper}, + {at::kBFloat16, decodeWrapper}, + {at::kFloat, decodeWrapper}, + {at::kDouble, decodeWrapper}, + {at::kComplexFloat, decodeWrapper>}, + {at::kComplexDouble, decodeWrapper>}}; + if (is_endian_independent) { memcpy(storage->mutable_data(), src + offset, count); - } else if (scalar_type == at::kBool) { - // Because of ASAN checks, that are failing whenever - // we are trying to get a value which is not 0 or 1, we have to manually - // convert original values to boolean ones. - torch::utils::THP_decodeBoolBuffer( - static_cast(storage->mutable_data()), src + offset, count); - } else if (scalar_type == at::kShort) { - torch::utils::THP_decodeInt16Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kInt) { - torch::utils::THP_decodeInt32Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kLong) { - torch::utils::THP_decodeInt64Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kHalf) { - torch::utils::THP_decodeHalfBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kBFloat16) { - torch::utils::THP_decodeBFloat16Buffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kFloat) { - torch::utils::THP_decodeFloatBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kDouble) { - torch::utils::THP_decodeDoubleBuffer( - static_cast(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kComplexFloat) { - torch::utils::THP_decodeComplexFloatBuffer( - static_cast*>(storage->mutable_data()), - src + offset, - do_byte_swap, - count); - } else if (scalar_type == at::kComplexDouble) { - torch::utils::THP_decodeComplexDoubleBuffer( - static_cast*>(storage->mutable_data()), - src + offset, - do_byte_swap, - count); } else { - TORCH_CHECK(false, "Unknown type: ", scalar_type); + auto it = decode_map.find(scalar_type); + if (it != decode_map.end()) { + it->second(storage->mutable_data(), src + offset, do_byte_swap, count); + } else { + TORCH_CHECK(false, "Unknown type: ", scalar_type); + } } PyBuffer_Release(&buffer); diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 37b1a7bd5983f6..3f24c6ecb40951 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -897,7 +897,7 @@ Tensor logsumexp_backward( grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); result = unsqueeze_multiple(result, dim, self.sym_sizes().size()); } - return grad * (self - result).exp(); + return grad * (self - result).exp().conj(); } Tensor logcumsumexp_backward( @@ -6689,7 +6689,8 @@ Tensor logsumexp_jvp( // forward auto self_p_exp = [&self_p, &dim]() { if (self_p.sym_numel() > 0) { - return (self_p - at::amax(self_p, dim, true)) + // Use only the real part for complex tensors + return (self_p - at::amax(at::real(self_p), dim, true)) .exp(); // Use the exp-normalize trick } else { // amax fails if numel() == 0, in which case it doesn't matter anyway diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 2f53924fd9111d..a79ee3d1608777 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -194,8 +194,9 @@ struct CppNode : public Node { if (!T::is_traceable) { throw std::runtime_error( std::string( - "compiled_args not implemented for non-traceable node: ") + - name()); + "Attempting to trace a potentially unsafe C++ autograd function: ") + + name() + + ". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/."); } // although neither of the 2 methods below have uniqueness guarantees diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index de13aac33b38e8..1be6242909af71 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -819,9 +819,15 @@ static variable_list call_tensor_pre_hooks(Node& fn, variable_list inputs) { static variable_list call_post_hooks( Node& fn, variable_list outputs, - const variable_list& inputs) { + const variable_list& inputs, + const bool had_post_hooks) { for (const auto& hook : fn.post_hooks()) { - outputs = (*hook)(outputs, inputs); + if (had_post_hooks) { + outputs = (*hook)(outputs, inputs); + } else { + variable_list null_inputs; + outputs = (*hook)(outputs, null_inputs); + } } return outputs; } @@ -976,11 +982,8 @@ static variable_list call_function( return ss.str(); }); - if (has_post_hooks) { - // NOLINTNEXTLINE(bugprone-use-after-move) - return call_post_hooks(fn, std::move(outputs), inputs); - } - return outputs; + // NOLINTNEXTLINE(bugprone-use-after-move) + return call_post_hooks(fn, std::move(outputs), inputs, has_post_hooks); } void Engine::evaluate_function( diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index c7a7a9d049ec19..10d1c2e7ef786c 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -261,7 +261,22 @@ struct AddGenericMetadata : public MetadataBase { // Add metadata for kwinputs if exist for (const auto& [key, val] : op_event.kwinputs_) { - addMetadata(key, ivalueToStr(val)); + if (key == "stream" && !val.isInt()) { + LOG(WARNING) << "Inputted stream is not an int for op: " + << op_event.name_ << " skipping"; + continue; + } + + // Until needed, lets limit the kwargs to only ints, doubles, strings and + // bools + if (!val.isInt() && !val.isDouble() && !val.isString() && !val.isBool()) { + LOG(WARNING) << "Inputted kwarg: " << key + << " is not an int, double, string, or bool for op: " + << op_event.name_ << " skipping"; + continue; + } + bool isString = val.isString(); + addMetadata(key, ivalueToStr(val, isString)); } // Add extra metadata if any for (const auto& [key, val] : op_event.extra_meta_) { diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 1feb1f41cdd916..92890a1509e8ef 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -664,6 +664,10 @@ void initTorchFunctions(PyObject* module) { !at::functionalization::impl::isFunctionalTensor(o)); at::functionalization::impl::replace_(t, o); }); + py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) { + TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); + return at::functionalization::impl::isBaseTensor(t); + }); py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) { TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 5d1d17c597af32..72d7a6c76d7441 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -47,7 +47,7 @@ inline PyObject* wrap(c10::complex value) { } inline PyObject* wrap(void* value) { - return THPUtils_packInt64(reinterpret_cast(value)); + return PyLong_FromVoidPtr(value); } inline PyObject* wrap(THPDtype* dtype) { diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 700623782fe47d..84eb864d2ceca9 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -8,10 +8,11 @@ void initModule(PyObject* module) { auto m = py::handle(module).cast(); auto cpu = m.def_submodule("_cpu", "cpu related pybind."); - cpu.def("_is_cpu_support_avx2", at::cpu::is_cpu_support_avx2); - cpu.def("_is_cpu_support_avx512", at::cpu::is_cpu_support_avx512); - cpu.def("_is_cpu_support_avx512_vnni", at::cpu::is_cpu_support_avx512_vnni); - cpu.def("_is_cpu_support_amx_tile", at::cpu::is_cpu_support_amx_tile); + cpu.def("_is_avx2_supported", at::cpu::is_avx2_supported); + cpu.def("_is_avx512_supported", at::cpu::is_avx512_supported); + cpu.def("_is_avx512_vnni_supported", at::cpu::is_avx512_vnni_supported); + cpu.def("_is_avx512_bf16_supported", at::cpu::is_avx512_bf16_supported); + cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported); cpu.def("_init_amx", at::cpu::init_amx); cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size); cpu.def("_L2_cache_size", at::cpu::L2_cache_size); diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index c6af163481481e..5220e86233bd67 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -210,8 +210,8 @@ void CUDAPluggableAllocator::recordStream( } } -c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: - getDeviceStats(c10::DeviceIndex device) { +c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats( + c10::DeviceIndex device) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getDeviceStats. " diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 70014b2fa0b37b..8652ef0f2bfde8 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -115,7 +115,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator void recordStream(const c10::DataPtr&, streamType stream) override; - c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device) override; void resetAccumulatedStats(c10::DeviceIndex device) override; void resetPeakStats(c10::DeviceIndex device) override; diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index b73316cfc1778f..0bb76907ee0f7c 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -123,10 +123,12 @@ static PyObject* THCPEvent_get_device(THCPEvent* self, void* unused) { } static PyObject* THCPEvent_record(PyObject* _self, PyObject* _stream) { - HANDLE_TH_ERRORS - auto self = (THCPEvent*)_self; - auto stream = (THCPStream*)_stream; - self->cuda_event.record(stream->cuda_stream); + HANDLE_TH_ERRORS { + auto self = (THCPEvent*)_self; + auto stream = (THCPStream*)_stream; + pybind11::gil_scoped_release no_gil{}; + self->cuda_event.record(stream->cuda_stream); + } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 1cc6f7378e80c4..461a23e651924b 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -542,9 +542,20 @@ PyObject* THCPModule_setMemoryFraction(PyObject* _unused, PyObject* args) { Py_RETURN_NONE; } +PyObject* THCPModule_hostEmptyCache(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + at::cuda::CachingHostAllocator_emptyCache(); + } + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + PyObject* THCPModule_emptyCache(PyObject* _unused, PyObject* noargs) { - HANDLE_TH_ERRORS - c10::cuda::CUDACachingAllocator::emptyCache(); + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + c10::cuda::CUDACachingAllocator::emptyCache(); + } END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -554,10 +565,10 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const auto device_index = THPUtils_unpackDeviceIndex(arg); - using c10::cuda::CUDACachingAllocator::DeviceStats; - using c10::cuda::CUDACachingAllocator::Stat; - using c10::cuda::CUDACachingAllocator::StatArray; - using c10::cuda::CUDACachingAllocator::StatType; + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; const auto statToDict = [](const Stat& stat) { py::dict dict; @@ -978,11 +989,10 @@ static void registerCudaDeviceProperties(PyObject* module) { "max_threads_per_multi_processor", &cudaDeviceProp::maxThreadsPerMultiProcessor) .def_readonly("warp_size", &cudaDeviceProp::warpSize) -#if !USE_ROCM - // NVIDA only property +#if (defined(USE_ROCM) && ROCM_VERSION >= 60100) || !USE_ROCM .def_readonly( "regs_per_multiprocessor", &cudaDeviceProp::regsPerMultiprocessor) -#endif // USE_ROCM +#endif // HIP-only property; reuse name attribute for CUDA builds .def_readonly( "gcnArchName", @@ -1281,6 +1291,13 @@ static void registerCudaPluggableAllocator(PyObject* module) { }); }); + m.def( + "_cuda_beginAllocateToPool", + [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) { + c10::cuda::CUDACachingAllocator::beginAllocateToPool( + device, mempool_id, [](cudaStream_t) { return true; }); + }); + m.def( "_cuda_endAllocateCurrentStreamToPool", [](c10::DeviceIndex device, at::cuda::MempoolId_t mempool_id) { @@ -1830,6 +1847,7 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_cudaHostAllocator, METH_NOARGS, nullptr}, + {"_host_emptyCache", THCPModule_hostEmptyCache, METH_NOARGS, nullptr}, {"_cuda_cudaCachingAllocator_raw_alloc", THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index afc904f2562d1b..22bcb44109a49d 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -832,7 +832,10 @@ void all2all_single_equal_split( const auto* sendbuff = reinterpret_cast(input.const_data_ptr()); auto* recvbuff = reinterpret_cast(output.data_ptr()); auto comm = to_nccl_comm(_comm); -#if defined(USE_ROCM) +#if defined(USE_ROCM) || defined(NCCL_ALLTOALL_SUPPORTED) + // NCCL_ALLTOALL_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoall) vs those + // inside traditional p2p operations. NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream)); #else NCCL_CHECK(ncclCommCount(comm, &numranks)); @@ -877,6 +880,21 @@ void all2all_single_unequal_split( auto type = to_nccl_data_type(_type); auto comm = to_nccl_comm(_comm); +#ifdef NCCL_ALLTOALLV_SUPPORTED + // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoallv) vs those + // inside traditional p2p operations. + NCCL_CHECK(ncclAllToAllv( + sendbuff, + sendcounts, + senddispls, + recvbuff, + recvcounts, + recvdispls, + type, + comm, + stream.stream())); +#else int numranks; NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclGroupStart()); @@ -905,6 +923,7 @@ void all2all_single_unequal_split( #else NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); #endif +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif @@ -924,6 +943,48 @@ void all2all( using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); +#ifdef NCCL_ALLTOALLV_SUPPORTED + // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv + // operations issued as a part of the collective (e.g. alltoallv) vs those + // inside traditional p2p operations. + TORCH_INTERNAL_ASSERT( + outputTensors.size() == inputTensors.size(), + "number of input tensors is not equal to number of output tensors"); + std::vector sendCounts(inputTensors.size()); + std::vector sendDisps(inputTensors.size()); + std::vector recvCounts(outputTensors.size()); + std::vector recvDisps(outputTensors.size()); + uintptr_t sendBase = reinterpret_cast(inputTensors[0].data_ptr()); + uintptr_t recvBase = reinterpret_cast(outputTensors[0].data_ptr()); + size_t dtypeSize = inputTensors.front().element_size(); + + for (const auto r : c10::irange(outputTensors.size())) { + sendCounts[r] = inputTensors[r].numel(); + auto sendOffset = + reinterpret_cast(inputTensors[r].data_ptr()) - sendBase; + TORCH_INTERNAL_ASSERT( + sendOffset % dtypeSize == 0, + "sendOffset is not divisible by dtypeSize"); + sendDisps[r] = sendOffset / dtypeSize; + recvCounts[r] = outputTensors[r].numel(); + auto recvOffset = + reinterpret_cast(outputTensors[r].data_ptr()) - recvBase; + TORCH_INTERNAL_ASSERT( + recvOffset % dtypeSize == 0, + "recvOffset is not divisible by dtypeSize"); + recvDisps[r] = recvOffset / dtypeSize; + } + NCCL_CHECK(ncclAllToAllv( + inputTensors[0].data_ptr(), + sendCounts.data(), + sendDisps.data(), + outputTensors[0].data_ptr(), + recvCounts.data(), + recvDisps.data(), + to_nccl_data_type(inputTensors.front()), + comm, + stream.stream())); +#else NCCL_CHECK(ncclGroupStart()); for (const auto r : c10::irange(outputTensors.size())) { at::Tensor& input = inputTensors[r]; @@ -953,6 +1014,7 @@ void all2all( #else NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); #endif +#endif #else AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); #endif diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 46a2426fd5f4c0..47ace12db6c3fc 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -21,7 +21,7 @@ constexpr int64_t kCommInitBusyWaitMillis = 10; namespace c10d { ncclComm_t NCCLComm::getNcclComm() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (aborted_) { auto commFailureMsg = commFailureReason_ != std::nullopt ? c10::str(" Original reason for failure was: ", *commFailureReason_) @@ -84,7 +84,9 @@ std::shared_ptr NCCLComm::split( std::nullopt); ++source->ncclCommSplitCounter_; comm->rank_ = rank; - comm->initialized_ = true; + if (!nccl_use_nonblocking()) { + comm->initialized_ = true; + } return comm; } #endif @@ -391,7 +393,7 @@ std::optional NCCLTraceBuffer::record( } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); auto te = Entry{ id_, @@ -448,7 +450,7 @@ void NCCLTraceBuffer::record_pg_ranks( if (!enabled_) { return; } - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); pg_name_to_ranks_[pg_name] = ranks; } @@ -468,7 +470,7 @@ void NCCLTraceBuffer::update_state(Entry& r) { } std::vector NCCLTraceBuffer::dump_entries() { - C10D_LOCK_GUARD(guard, mutex_); + std::lock_guard guard(mutex_); std::vector result; result.reserve(entries_.size()); result.insert(result.end(), entries_.begin() + next_, entries_.end()); @@ -493,7 +495,7 @@ void NCCLTraceBuffer::retire_id( Event* endEvent = nullptr; std::optional duration = std::nullopt; - C10D_LOCK_GUARD(guard, mutex_); + std::unique_lock guard(mutex_); Entry* entry = &entries_.at(*id % max_entries_); if (entry->id_ == *id) { @@ -534,6 +536,7 @@ const c10::List NCCLTraceBuffer::getCollectiveTrace( bool includeStacktraces, bool onlyActive) { auto entries = new_list(); + // Entries are returned in the order they were recorded auto result = dump_entries(); std::vector tracebacks; torch::SymbolizedTracebacks stracebacks; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index dabd5a7d21429d..070cbd34b37970 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ @@ -177,14 +176,14 @@ namespace c10d { #define DEFINE_CONSTANT(name, value) \ static c10::IValue name = value; \ static std::string name##_str = value; -DEFINE_CONSTANT(entries_key, "entries"); -DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state"); -DEFINE_CONSTANT(version_key, "version"); // Update whenever changing contents or formatting of the dump // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). -DEFINE_CONSTANT(version_val, "2.3"); +DEFINE_CONSTANT(version_val, "2.4"); +DEFINE_CONSTANT(entries_key, "entries"); +DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state"); +DEFINE_CONSTANT(version_key, "version"); DEFINE_CONSTANT(pg_config_key, "pg_config"); DEFINE_CONSTANT(pg_status_key, "pg_status"); DEFINE_CONSTANT(record_id_key, "record_id"); @@ -271,7 +270,7 @@ class NCCLComm { ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (ncclComm_ && initialized_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since @@ -363,7 +362,7 @@ class NCCLComm { NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. - C10D_LOCK_GUARD(lock, other.mutex_); + std::unique_lock lock(other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); @@ -373,13 +372,13 @@ class NCCLComm { ncclComm_t getNcclComm(); std::optional getNcclCommFailureReason() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return commFailureReason_; } void ncclCommAbort( std::optional commFailureReason = std::nullopt) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_ && !initialized_) { // Should not abort twice. @@ -427,7 +426,7 @@ class NCCLComm { } bool isAborted() const { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); return aborted_; } @@ -436,7 +435,7 @@ class NCCLComm { } ncclResult_t checkForNcclError() { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; @@ -451,7 +450,7 @@ class NCCLComm { } ncclResult_t registerSegment(void* ptr, size_t size) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER // We register only segments from cache allocator // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always @@ -482,7 +481,7 @@ class NCCLComm { } ncclResult_t deregisterSegment(void* ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); #ifdef NCCL_HAS_COMM_REGISTER TORCH_CHECK( registeredSegmentHandles_.count(ptr) == 1, @@ -519,7 +518,7 @@ class NCCLComm { bool aborted_; uint64_t ncclCommSplitCounter_{0}; ncclResult_t ncclAsyncErr_; - mutable std::timed_mutex mutex_; + mutable std::mutex mutex_; // Rank that this communicator corresponds to. int rank_; // Optional reason for communicator failure, provided by ProcessGroupNCCL for @@ -638,7 +637,7 @@ struct NCCLTraceBuffer { bool enabled_ = false; bool capture_cpp_stack_ = false; - std::timed_mutex mutex_; + std::mutex mutex_; std::vector entries_; size_t max_entries_ = 0; size_t next_ = 0; diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu new file mode 100644 index 00000000000000..d256413d60a10f --- /dev/null +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -0,0 +1,256 @@ +#ifdef USE_C10D_NCCL + +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + +// CUDA kernel to check if data has NAN, device side assert +// is raised if NAN is found + +// Using ulong2 as a "byte pack", with 16 bytes, for efficient data load +union BytePack16 { + ulong2 ul2; + uint64_t ul[2]; +}; + +typedef union BytePack16 BytePack; + +// AMD HIP doesn't define `__trap()`, using `assert` instead +#ifdef USE_ROCM +#define __trap() assert(0) +#endif + +//// Start of templated functions for checking NaNs inside a BytePack + +// (i) General implementation (aka fallback) +// We use a for loop to iterate over the elements in a BytePack. +// EltPerPack would be greater than 8 if falling in this case. + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + #pragma unroll 8 + for (int i = 0; i < EltPerPack; i++) { + if (isnan(data[i])) __trap(); + } + } +}; + +// (ii) Template Specialization for 8-byte data types, e.g. double +// EltPerPack = 16 / 8 = 2 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1])) __trap(); + } +}; + +// (iii) Template specialization for 4-byte data types, e.g. float32 +// EltPerPack = 16 / 4 = 4 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) __trap(); + } +}; + +// (iv) Template specialization for 2-byte data types, e.g. float16, bfloat16, half. +// EltPerPack = 16 / 2 = 8 + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + T* data = (T*)tmp; + if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) || + isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) { + __trap(); + } + } +}; + +// (v) Template specialization for Float8 types. +// EltPerPack = 16 / 1 = 16 + +// We want to check 8 x FP8 simultaneously, hence this template definition. +template +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) = delete; + /* + { + // `static_assert` in template definition requires c++23 onwards. + // But the error message still applies if you find yourself here. + static_assert( + false, + "You should never call this template definition because it is empty. You " + "can follow the example of Float8_e4m3fn below to implement the check for " + "your new datatype." + ); + } + */ +}; + +// isnan condition for Float8_e4m3fn: +// (x & 0b01111111) == 0b01111111 +// i.e. +// (x & 0x7f) == 0x7f + +// The algorithm is as follows: +// (1) Mask out the most significant bit with mask 0x7f. +// (2) If the result is 0x7f (is nan), the following arithmetic would cause the +// 8th bit to be 1: x[i] = x[i] + 0x01 +// (3) Only leave the 8th bit by masking with 0x80. +// (4) If any x[i] is nan, then the whole x != 0. + +template<> +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0101010101010101ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } +}; + +// isnan condition for Float8_e5m2: +// (x & 0x7f) > 0x7c +// This case does not overflow: 0x7c + 0x03 == 0x7f but adding 0x03 to anything +// greater than 0x7c will overflow. + +template<> +struct HasNanFP8x8 { + static __device__ __forceinline__ bool check(uint64_t fp8x8) { + auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL; + auto incremented = t + 0x0303030303030303ULL; + auto overflow = incremented & 0x8080808080808080ULL; + return overflow != 0; + } +}; + +template +struct CheckBytePack { + static __device__ __forceinline__ void check(BytePack* tmp) { + if (HasNanFP8x8::check(tmp->ul[0]) || HasNanFP8x8::check(tmp->ul[1])) + __trap(); + } +}; + +//// End of templated functions for checking NaNs inside a BytePack + + +// Fast-path check routine: +// each thread will load and check 8 BytePacks in this routine + +// Create a tmp buffer of size 8, also unroll for loop by 8 +#define UNROLL 8 + +template +__device__ __forceinline__ void checkChunk(BytePack* ptr) { + BytePack tmp[UNROLL]; + int nWorkers = blockDim.x * gridDim.x; + // First load values from global memory into tmp buffer + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + tmp[j] = ptr[nWorkers * j]; + } + // Then check each BytePack in the tmp buffer + #pragma unroll 8 + for (int j = 0; j < UNROLL; j++) { + CheckBytePack::check(tmp + j); + } + // Note: we separate the check from the load for efficient loading +} + +// Align address of `ptr` up, to the alignment of `T` +#define ALIGN_UP(ptr, T) (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T)) + +// This is the host-facing kernel + +template +__global__ void checkForNaN(T* data, size_t size) { + constexpr int EltPerPack = sizeof(BytePack) / sizeof(T); + // Offset of current thread + size_t offset = blockIdx.x * blockDim.x + threadIdx.x; + + // Align input address up to BytePack in case it is not + T* ptrAlign = (T*)ALIGN_UP(data, BytePack); + // Pre-process the data before alignment + size_t preProcElts = min(ptrAlign - data, size); + // Read memory by T (slow). One iter is enough bc the number of threads would + // be bigger than `preProcElts` + if (offset < preProcElts) { + if (isnan(data[offset])) __trap(); + } + // We have processes this amount of data + size -= preProcElts; + + // Start BytePack processing + BytePack* ptr = (BytePack*)ptrAlign; + // Size of input data in unit of BytePack + size_t sizeInBP = size * sizeof(T) / sizeof(BytePack); + // Number of BytePacks processed in one fast-path iteration + size_t loopSize = blockDim.x * gridDim.x * UNROLL; + + // Fast path + // The condition below makes sure there is enough data to process (`loopSize`) + for (; offset + loopSize <= sizeInBP; offset += loopSize) { + checkChunk(ptr + offset); + } + + // The rest data goes on slow path + // We just do regular load and check + for (; offset < sizeInBP; offset += blockDim.x * gridDim.x) { + BytePack tmp = ptr[offset]; + CheckBytePack::check(&tmp); + } + + // We can still have a tail smaller than 1 BytePack + // TODO: merge this tail check with head check to make them concurrent + if (threadIdx.x < size % EltPerPack) { + T* tailPtr = (T*)(ptr + sizeInBP); + if (isnan(tailPtr[threadIdx.x])) __trap(); + } +} + +// CHECK if a Tensor contains NAN in any of its element +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { + // skip check for non float types + if (!torch::is_floating_point(tensor)) { + return; + } + const size_t maxNumThreadsPerBlock = 512; + const size_t maxNumBlocks = 24; + const size_t numThreadsPerBlock = + std::min(maxNumThreadsPerBlock, tensor.numel()); + + const size_t numBlocks = std::min( + maxNumBlocks, + (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock); + + AT_DISPATCH_FLOATING_TYPES_AND4( + at::ScalarType::Half, + at::ScalarType::BFloat16, + at::ScalarType::Float8_e4m3fn, + at::ScalarType::Float8_e5m2, + tensor.scalar_type(), + "checkForNaN", + [&] { + checkForNaN<<>>( + tensor.data_ptr(), tensor.numel()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NanCheck.hpp b/torch/csrc/distributed/c10d/NanCheck.hpp new file mode 100644 index 00000000000000..cc9a5867c3dd40 --- /dev/null +++ b/torch/csrc/distributed/c10d/NanCheck.hpp @@ -0,0 +1,16 @@ +#pragma once + +#ifdef USE_C10D_NCCL + +#include +#include + +namespace c10d { + +// Check for NaNs in a tensor on a given stream. If any are found, throw a +// device-side error. +void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream); + +} // namespace c10d + +#endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 70356b3bf382ce..f565de2013260c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -14,24 +14,6 @@ namespace c10d { -static ProcessGroup::BackendType strToBackendType(std::string_view backend) { - if (backend == "undefined") { - return ProcessGroup::BackendType::UNDEFINED; - } else if (backend == "gloo") { - return ProcessGroup::BackendType::GLOO; - } else if (backend == "nccl") { - return ProcessGroup::BackendType::NCCL; - } else if (backend == "xccl") { - return ProcessGroup::BackendType::XCCL; - } else if (backend == "ucc") { - return ProcessGroup::BackendType::UCC; - } else if (backend == "mpi") { - return ProcessGroup::BackendType::MPI; - } else { - return ProcessGroup::BackendType::CUSTOM; - } -} - std::string opTypeToString(OpType opType) { switch (opType) { case OpType::BROADCAST: @@ -121,13 +103,11 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size, - c10::intrusive_ptr options) + int size) : store_(store), rank_(rank), size_(size), - options_(std::move(options)), - backendType_(strToBackendType(options_->backend)), + backendType_(BackendType::UNDEFINED), dist_debug_level_(debug_level()) { C10_LOG_API_USAGE_ONCE("c10d.process_group"); } diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 73fc2bda701327..83d2729fc43d43 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -45,24 +45,6 @@ namespace c10d { // class TORCH_API ProcessGroup : public torch::CustomClassHolder { public: - // ProcessGroup Options is a base struct that defines the basic options - // when constructing a ProcessGroup. Each ProcessGroup subclass should - // extend this struct and define its options if it wants to provide more - // config options (beyond basic ones defined here) to end user. - struct TORCH_API Options : torch::CustomClassHolder { - explicit Options( - std::string backend, - std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout) - : timeout(timeout), backend(std::move(backend)) {} - ~Options() override = default; - - std::chrono::milliseconds timeout; - - // backend name - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const std::string backend; - }; - enum BackendType : uint8_t { UNDEFINED = 0, GLOO = 1, @@ -73,6 +55,45 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { XCCL = 6, }; + static std::string backendTypeToString(const BackendType& type) { + switch (type) { + case BackendType::GLOO: + return "gloo"; + case BackendType::NCCL: + return "nccl"; + case BackendType::XCCL: + return "xccl"; + case BackendType::UCC: + return "ucc"; + case BackendType::MPI: + return "mpi"; + case BackendType::UNDEFINED: + return "undefined"; + case BackendType::CUSTOM: + return "custom"; + default: + TORCH_CHECK(false, "THis should never happen!"); + } + }; + + static BackendType strToBackendType(const std::string& backend) { + if (backend == "undefined") { + return BackendType::UNDEFINED; + } else if (backend == "gloo") { + return BackendType::GLOO; + } else if (backend == "nccl") { + return BackendType::NCCL; + } else if (backend == "xccl") { + return BackendType::XCCL; + } else if (backend == "ucc") { + return BackendType::UCC; + } else if (backend == "mpi") { + return BackendType::MPI; + } else { + return BackendType::CUSTOM; + } + }; + // Not used, set for backwards compatibility and only used for TypeDef in // Ops.cpp explicit ProcessGroup(int rank, int size); @@ -80,8 +101,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { explicit ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, - int size, - c10::intrusive_ptr options); + int size); ~ProcessGroup() override; int getRank() const { @@ -104,7 +124,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { } virtual const std::string getBackendName() const { - return options_->backend; + return backendTypeToString(backendType_); }; BackendType getBackendType() const { @@ -612,10 +632,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { opts.timeout.count()); } - c10::intrusive_ptr getOptions() { - return options_; - } - bool hasBackends() { return !deviceTypeToBackendType_.empty(); } @@ -656,6 +672,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendTypeToBackend_.at(backendType_); } + void setDefaultBackend(const BackendType& backendType) { + backendType_ = backendType; + } + + void setDefaultBackend(const std::string& backend) { + backendType_ = strToBackendType(backend); + } + c10::intrusive_ptr getBackend(c10::DeviceType deviceType); c10::intrusive_ptr getBackend(BackendType backendType) const { @@ -728,9 +752,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const int size_; // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const c10::intrusive_ptr options_; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - const BackendType backendType_; + BackendType backendType_; std::string pg_desc_; // Debug level setting. It is parsed once when ProcessGroup is constructed and diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 96ea93f7413b6c..51fa248ec403b9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -602,7 +602,7 @@ uint64_t ProcessGroupGloo::RecvWork::getSequencenumber() const { } int ProcessGroupGloo::RecvWork::sourceRank() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); return srcRank_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1e19fc44b4e620..8a7aefdc238c4e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -24,13 +24,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -302,7 +302,7 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // hooks are called outside the scope of any PG, thus we need traverse // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; -static std::timed_mutex ncclCommDevIdxMapMutex; +static std::mutex ncclCommDevIdxMapMutex; static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -315,7 +315,7 @@ void cacheAllocatorRegisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -333,7 +333,7 @@ void cacheAllocatorDeregisterHook( return; } - C10D_LOCK_GUARD(lock, ncclCommDevIdxMapMutex); + std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& it : ncclCommDevIdxMap) { auto& ncclComm = it.first; auto& devIdx = it.second; @@ -422,22 +422,6 @@ std::future launchAsyncGilCheck() { return resultFuture; } -// Return CUDA device with ordinal given by input rank. If we aren't -// bound to a specific device, there is no strict guarantee that this -// heuristic is the correct assignment of ranks to GPUs that Python -// layers use, but in practice it tends to be. Fortunately we don't -// rely on this for correctness of any tensor operations, just for -// ancillary uses like barriers. -at::Device ProcessGroupNCCL::guessDeviceForRank() const { - TORCH_CHECK_WITH(ValueError, rank_ >= 0, "Invalid rank ", rank_); - if (getBoundDeviceId()) { - return *getBoundDeviceId(); - } else { - int16_t deviceIdx = static_cast(rank_ % localDeviceCount_); - return at::Device(at::DeviceType::CUDA, deviceIdx); - } -} - const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; constexpr int64_t kSynchronizeBusyWaitMillis = 10; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; @@ -552,7 +536,7 @@ void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { } auto exception_ptr = checkForNCCLErrors(); - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; if (exception_) { LOG(ERROR) << logPrefix() << "Collective " << *this @@ -568,7 +552,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { void ProcessGroupNCCL::WorkNCCL::setException( std::exception_ptr exception_ptr) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); exception_ = exception_ptr; } @@ -777,12 +761,12 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( bool timing) { auto deleter = [this, timing](at::cuda::CUDAEvent* event) { - C10D_LOCK_GUARD(lock, this->cacheMutex_); + std::lock_guard lock(this->cacheMutex_); this->eventsArray_[timing ? 1 : 0].push_back(event); }; at::cuda::CUDAEvent* event = nullptr; { - C10D_LOCK_GUARD(lock, cacheMutex_); + std::lock_guard lock(cacheMutex_); auto events = eventsArray_[timing ? 1 : 0]; if (!events.empty()) { event = events.back(); @@ -850,6 +834,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( // both timeout and other errors. dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || (dist_debug_level_ >= DebugLevel::Detail); + sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); // logging C++ stack isn't safe. Introduce a variable to control it. logCppStackOnUncleanShutdown_ = getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); @@ -1087,9 +1072,8 @@ void ProcessGroupNCCL::waitForPendingWorks() { while (true) { { std::lock(workMetaListMutex_, completedWorkListMutex_); - std::lock_guard lockWork( - workMetaListMutex_, std::adopt_lock); - std::lock_guard lockHook( + std::lock_guard lockWork(workMetaListMutex_, std::adopt_lock); + std::lock_guard lockHook( completedWorkListMutex_, std::adopt_lock); if (workMetaList_.empty() && completedWorkList_.empty()) { @@ -1110,8 +1094,18 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException) { + bool throwException, + bool log) { std::string errorMsg; + + ::c10d::C10dLoggingData data; + if (log) { + data.integers["pg_id"] = local_id_; + data.integers["rank"] = rank_; + data.integers["global_rank"] = globalRank(); + data.strings["flight_recorder_version"] = c10d::version_val_str; + } + TORCH_CHECK(fut.valid(), "Expected a valid future"); std::future_status status = fut.wait_for(timeOutMilSec); if (status == std::future_status::ready) { @@ -1122,20 +1116,31 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( if (result) { LOG(INFO) << logPrefix() << "future is successfully executed for: " << futDescription; + if (log) { + data.strings["status"] = "SUCCESS"; + } } } catch (const std::exception& e) { errorMsg = c10::str( logPrefix(), - "Exception thrown when waitng for future ", + "Exception thrown when waiting for future ", futDescription, ": ", e.what()); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = e.what(); + } LOG(ERROR) << errorMsg; } catch (...) { errorMsg = c10::str( logPrefix(), - "Unknown exception thrown when waitng for future ", + "Unknown exception thrown when waiting for future ", futDescription); + if (log) { + data.strings["status"] = "EXCEPTION"; + data.strings["exception"] = "Unknown exception"; + } LOG(ERROR) << errorMsg; } } else { @@ -1146,8 +1151,15 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( " timed out after ", timeOutMilSec.count(), " ms"); + data.strings["status"] = "TIMEOUT"; LOG(ERROR) << errorMsg; } + if (log) { + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(data); + } + } if (throwException && !errorMsg.empty()) { C10_THROW_ERROR(DistBackendError, errorMsg); } @@ -1164,6 +1176,10 @@ void ProcessGroupNCCL::abortCommsFromMap( at::cuda::OptionalCUDAGuard gpuGuard; at::DeviceIndex deviceIndex = getIndexFromDeviceKey(devName); if (deviceIndex >= 0) { + // For P2P comms, the deviceIndex could be -1 (invalid), as the keys in + // the map could be non deviceIndex, but rank to rank numbers. So we + // indeed need to check if deviceIndex >= 0 + // TODO: fix `getIndexFromDeviceKey` or fix `DeviceKey` gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " @@ -1179,20 +1195,15 @@ void ProcessGroupNCCL::abortCommsFromMap( // their responsibility to destroy the process group and recreate // it to recover from errors. - c10::StreamId streamId = -1; - if (ncclStreams_.find(devName) != ncclStreams_.end()) { - auto stream = ncclStreams_.at(devName); - streamId = stream.id(); - } - LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed " - << " communicator on CUDA device: " << devName - << " with stream: " << streamId; + << " communicator on CUDA device: " << devName; } } // Abort all communicators on this rank bool ProcessGroupNCCL::abort(std::optional abortReason) { + // This will log counter for how long the abort actually takes. + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); // Remove record from global ncclCommDevIdxMapMutex before aboarting, // so that a new cache segment would not register to already aborded // communicators. Note that ncclCommDevIdxMap is a global container which may @@ -1205,7 +1216,7 @@ bool ProcessGroupNCCL::abort(std::optional abortReason) { } ncclCommDevIdxMapMutex.unlock(); - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); abortCommsFromMap(inInitializationCommMap_, abortReason); return true; @@ -1224,7 +1235,8 @@ void ProcessGroupNCCL::shutdown(std::optional reason) { std::future fut = std::async( std::launch::async, [this, &reason]() { return this->abort(reason); }); - waitForFutureOrTimeout(fut, options_->timeout, "ProcessGroup abort", true); + waitForFutureOrTimeout( + fut, options_->timeout, "ProcessGroup abort", true, false); LOG(INFO) << logPrefix() << "ProcessGroupNCCL aborts successfully."; // We need to wait for abort to finish before we can safely shut down @@ -1274,8 +1286,8 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { // Serialize all calls to this function to avoid corrupting data, but allow // multiple calls in one runtime. User is responsible for preserving the // output file from an earlier call before a later call overwrites it. - static std::timed_mutex writeDebugInfoMutex; - C10D_LOCK_GUARD(lock, writeDebugInfoMutex); + static std::mutex writeDebugInfoMutex; + std::lock_guard lock(writeDebugInfoMutex); LOG(ERROR) << logPrefix() << "ProcessGroupNCCL preparing to dump debug info."; if (ncclTraceBufferSize_ > 0) { // We dump nccl trace into local disk by default and users can register @@ -1354,7 +1366,7 @@ void ProcessGroupNCCL::heartbeatMonitor() { // This won't have any lock since this lock is only used here. // Please be aware that mutex `monitorMutex_` should not be used // somewhere else to avoid the deadlock. - C10D_LOCK_GUARD(lock, monitorMutex_); + std::unique_lock lock(monitorMutex_); if (monitorWakeUpCV_.wait_for( lock, std::chrono::milliseconds(monitorPollInterval), [&] { return terminateHeartbeatMonitorThread_.load(); @@ -1486,11 +1498,13 @@ void ProcessGroupNCCL::heartbeatMonitor() { std::future asyncDebugDump = std::async( std::launch::async, [this]() { return this->dumpDebuggingInfo(); }); - // wait for the dump until timeout + // wait for the dump until timeout - log data waitForFutureOrTimeout( asyncDebugDump, std::chrono::milliseconds(waitTimeoutDumpInMilSec_), - "Flight recorder dump in heartbeatMonitor"); + "Flight recorder dump in heartbeatMonitor", + false, + true); } if (get_gil_checker() != nullptr) { @@ -1679,7 +1693,7 @@ const std::vector& ProcessGroupNCCL::groupRanks() const { void ProcessGroupNCCL::addEphemeralTimeout( const std::chrono::milliseconds& timeout) { - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); ephemeralTimeoutActive_ += timeout; } @@ -1702,7 +1716,7 @@ void ProcessGroupNCCL::watchdogHandler() { std::list completedWorkList; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::unique_lock lock(workMetaListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. workMetaListCV_.wait_for( @@ -1797,12 +1811,14 @@ void ProcessGroupNCCL::watchdogHandler() { } // signal the monitor thread on PG0 to start dumping shouldDump_.store(true); - // This sleep is used to give time for dumping before throwing - // exception - std::this_thread::sleep_for( - std::chrono::seconds(heartbeatTimeoutInSec_)); - LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ - << " giving time for flight recorder dumps to finish."; + if (sleepAfterException_) { + // This sleep is used to give time for dumping before throwing + // exception + std::this_thread::sleep_for( + std::chrono::seconds(heartbeatTimeoutInSec_)); + LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + << " giving time for flight recorder dumps to finish."; + } } catch (const std::exception& e) { LOG(ERROR) << logPrefix() << "Failed to set dump signal in tcpstore. " @@ -1874,7 +1890,7 @@ void ProcessGroupNCCL::watchdogHandler() { if (work.isCompleted()) { { // Reset the timeout and first work if the work is completed. - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (work.ownedEphermeralTimeout_.count() > 0) { ephemeralTimeoutActive_ -= work.ownedEphermeralTimeout_; ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; @@ -1889,7 +1905,7 @@ void ProcessGroupNCCL::watchdogHandler() { // Move Work object to completedWorkList_ to be consumed by the hook // thread { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + const std::lock_guard lock(completedWorkListMutex_); completedWorkList_.splice( completedWorkList_.end(), workMetaList_, it++); } @@ -1917,7 +1933,7 @@ void ProcessGroupNCCL::runHookLoop() { bool done = false; while (!done || !terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, completedWorkListMutex_); + std::unique_lock lock(completedWorkListMutex_); // We busy-poll the work vector every kWatchdogThreadSleepMillis // milliseconds as long as the atomic is True. completedWorkListCV_.wait_for( @@ -2090,7 +2106,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( } void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { TORCH_INTERNAL_ASSERT( false, @@ -2138,7 +2154,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( usedDeviceIdxs_.insert(device.index()); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); if (devNCCLCommMap_.find(deviceKey) != devNCCLCommMap_.end()) { // Reuse the cached communicator if there is one. return devNCCLCommMap_[deviceKey]; @@ -2164,7 +2180,9 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( bool batchP2P = ncclActiveGroupCounter_ > 0; bool singleP2POp = isP2POp(opType, batchP2P); - at::cuda::OptionalCUDAGuard gpuGuard; + // Get the device index + auto deviceIndex = device.index(); + at::cuda::OptionalCUDAGuard gpuGuard(device); // [Group Start/End Note] This is used to ensure that nccl communicator will // be created before communication primitives are called. Let's look at this @@ -2204,17 +2222,13 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } - // Get the device index - auto deviceIndex = device.index(); - gpuGuard.set_index(deviceIndex); - #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { TORCH_CHECK( options_->split_color != 0, "Must specify a non-zero color when splitting"); // Find a valid, healthy communicator to split from if possible. - C10D_LOCK_GUARD(lock, options_->split_from->mutex_); + std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; auto dit = other_comms.find(getKeyFromDevice(device)); if (dit != other_comms.end()) { @@ -2268,7 +2282,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( options_->is_high_priority_stream || force_high); { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); inInitializationCommMap_.emplace(deviceKey, ncclComm); } @@ -2518,7 +2532,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( const c10::intrusive_ptr& work, const c10::intrusive_ptr& option) { std::chrono::milliseconds timeout = option->timeout; - C10D_LOCK_GUARD(timeoutLock, mtxTimeoutExtension_); + std::lock_guard timeoutLock(mtxTimeoutExtension_); if (ephemeralTimeoutActive_.count() > 0) { timeout += ephemeralTimeoutActive_; } @@ -2531,7 +2545,7 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( c10::intrusive_ptr work) { if (!terminateProcessGroup_.load()) { - C10D_LOCK_GUARD(lock, workMetaListMutex_); + std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. // View tensors' destruction invokes autograd_meta, which // needs to be destructed in user thread. Otherwise will @@ -2653,19 +2667,19 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { template c10::intrusive_ptr ProcessGroupNCCL::collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { - if (enableNanCheck_) { - checkForNan(input); - } + bool avoidRecordStreams, + bool nanCheck) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + nanCheck &= enableNanCheck_; + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2674,7 +2688,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( seqCollective_++; op_id_++; - auto device = getDevice(input); + auto device = getDevice(inputs[0]); const auto key = getKeyFromDevice(device); auto ncclComm = getNCCLComm(key, device, opType); @@ -2699,25 +2713,26 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // First let NCCL streams wait for input tensors allocation streams syncStream(device, ncclEvents_[key], ncclStream); - std::vector inputs{input}; - std::vector outputs{output}; - bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; auto work = initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); // Store references to outputs to be used by WorkNCCL::result and operator<<. - work->outputs_ = - std::make_shared>(std::move(outputs)); + work->outputs_ = std::make_shared>(outputs); if (avoidRecordStreams) { work->stashed_for_allocator_safety_ = - std::make_shared>(); - work->stashed_for_allocator_safety_->push_back(input); + std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); + + if (nanCheck) { + for (const auto& input : inputs) { + checkForNan(input, ncclStream); + } + } // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { @@ -2737,25 +2752,35 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // // See [Sync Streams]. if (!avoidRecordStreams) { - if (!input.is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - input.values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - input.indices().storage().data_ptr(), ncclStream); + for (const auto& input : inputs) { + if (!input.is_sparse()) { + c10::cuda::CUDACachingAllocator::recordStream( + input.storage().data_ptr(), ncclStream); + } else { + // for sparse input case record streams on both index and value + // tensors + c10::cuda::CUDACachingAllocator::recordStream( + input.values().storage().data_ptr(), ncclStream); + c10::cuda::CUDACachingAllocator::recordStream( + input.indices().storage().data_ptr(), ncclStream); + } } } + +// Not all collectives have the same signature, e.g, all-reduce take in a Tensor +// as the input and output while all-to-all take in a vector of Tensors as input +// and output. Because we define the signature of the fn to take only single +// tensor as input and output, we need to do a hack to get the first element in +// the vector and pass it to fn. +// TODO: we should clean up this in future (by either entirely removing lambda's +// or removing input and output from lambda's signature). #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( - fn(input, output, comm, ncclStream), + fn(inputs[0], outputs[0], comm, ncclStream), ncclComm->getNcclCommFailureReason()); #else C10D_NCCL_CHECK_TIMEOUT( - fn(input, output, comm, ncclStream), + fn(inputs[0], outputs[0], comm, ncclStream), comm, ncclComm->getNcclCommFailureReason()); #endif @@ -2797,8 +2822,14 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as // multi-device per process is deprecated - work->numelIn_ = input.numel(); - work->numelOut_ = output.numel(); + work->numelIn_ = 0; + work->numelOut_ = 0; + for (const auto& input : inputs) { + work->numelIn_ += input.numel(); + } + for (const auto& output : outputs) { + work->numelOut_ += output.numel(); + } // Notify graphs before we check the capture status preemptively at::cuda::CUDAGraph::inc_pending_event_queries(); @@ -2878,7 +2909,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) @@ -3020,9 +3051,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( PreProcess pre, PostProcess post, const char* profilingTitle) { - if (enableNanCheck_) { - checkForNan(tensor); - } // avoidRecordStreams_ note: // send, recv, and irecv should be ok with avoidRecordStreams, // However, for isend, I don't think the API requires the user @@ -3149,7 +3177,13 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard; + at::cuda::OptionalCUDAGuard gpuGuard(device); + + // Only check for NaN for send ops, for recv ops `tensor` can be a random + // placeholder + if (enableNanCheck_ && opType == OpType::SEND) { + checkForNan(tensor, ncclStream); + } if (!coalescing_state_) { // Start event should only be recorded before the ncclGroupStart() @@ -3237,6 +3271,31 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } } +template +c10::intrusive_ptr ProcessGroupNCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle, + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + pre, + post, + opType, + profilingTitle, + avoidRecordStreams, + nanCheck); +} + template c10::intrusive_ptr ProcessGroupNCCL::collective( at::Tensor& input, @@ -3244,10 +3303,13 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( Fn fn, OpType opType, const char* profilingTitle, - bool avoidRecordStreams) { + bool avoidRecordStreams, + bool nanCheck) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; return collective( - input, - output, + inputs, + outputs, fn, [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, @@ -3255,7 +3317,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( c10::intrusive_ptr& work) {}, opType, profilingTitle, - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } template @@ -3505,6 +3568,9 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // avoidRecordStreams_ note: collective() will stash tensors. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); + return collective( tensor, tensor, @@ -3512,7 +3578,6 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBcast( input.data_ptr(), input.numel(), @@ -3523,7 +3588,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( }, OpType::BROADCAST, "nccl:broadcast", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } // _broadcast_oop adds an out-of-place broadcast in PGNCCL @@ -3542,7 +3608,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( ValueError, "Tensor input and output of _broadcast_oop must have the same number of elements "); } - + const auto root = opts.rootRank + opts.rootTensor; + bool nanCheck = (root == rank_); return collective( inputTensor, outputTensor, @@ -3550,7 +3617,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank + opts.rootTensor; return ncclBroadcast( input.data_ptr(), output.data_ptr(), @@ -3561,7 +3627,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, - "nccl:_broadcast_oop"); + "nccl:_broadcast_oop", + /*avoidRecordStreams=*/false, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::reduce( @@ -3638,7 +3706,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( ValueError, "Tensor input and output of _reduce_oop must have the same number of elements "); } - return collective( inputTensor, outputTensor, @@ -4025,40 +4092,54 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { globalRankStride, // globalRankStride this->getSize()); // worldSize - std::vector devices; + // Device to use for barrier + int barDevIdx = -1; - // Use user defined GPU device ids if provided + // Select device to use for barrier + // 1st choice: Use user defined GPU device ids if provided if (!opts.device_ids.empty()) { - for (auto device : opts.device_ids) { - devices.emplace_back(at::DeviceType::CUDA, device); - } - } else if (usedDeviceIdxs_.empty()) { + // Use the first device id because PG NCCL is single-device now + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + // 2nd choice: Use the bound GPU device id if available. + // Bounded device id can be passed to `init_process_group`. + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + // 3rd choice: infer the device id from the used device ids. + barDevIdx = *usedDeviceIdxs_.begin(); + } else { // This means there is not yet a NCCL collective being called // Here we have to use the best guesses and will use a single GPU to call // allreduce to achieve barrier. // In case the multiple processes fall into the same node, we use rank to // ensure that each process is on a different GPU - auto numGPUs = at::cuda::getNumGPUs(); - int16_t deviceIdx = static_cast(rank_ % numGPUs); - LOG(INFO) + // Note: it is better to use global rank because the group-local rank can be + // offset wrt the device id if intra-node GPUs are sharded into multiple + // dimensions. + barDevIdx = static_cast(globalRank() % localDeviceCount_); + LOG(WARNING) << logPrefix() << c10::str( " using GPU ", - deviceIdx, + barDevIdx, " to perform barrier as devices used by this process are currently unknown. ", "This can potentially cause a hang if this rank to GPU mapping is incorrect.", - "Specify device_ids in barrier() to force use of a particular device."); - devices.emplace_back(guessDeviceForRank()); - } else { - for (auto usedDeviceIdx : usedDeviceIdxs_) { - devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); - } + "Specify device_ids in barrier() to force use of a particular device,", + "or call init_process_group() with a device_id."); } - // Use one device only - auto device = devices.back(); + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. "); + auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + + // Create a dummy tensor on the device + // Note: we use zeros() instead of empty() to prevent barrier from triggering + // alarm when NaN checker is enabled. at::Tensor barrierTensor = - at::empty({1}, at::TensorOptions().device(device).dtype(at::kFloat)); + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + // All reduce to achieve the barrier auto work = allreduce_impl(barrierTensor); @@ -4217,8 +4298,8 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( this->getSize()); // worldSize return collective( - inputTensors[0], - outputTensors[0], + inputTensors, + outputTensors, [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, @@ -4411,9 +4492,11 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( // avoidRecordStreams_ note: collective() will stash inputTensors and // outputs, which == outputTensors[0] on the root rank where it matters. + + auto inputs = std::vector{inputTensor}; return collective( - inputTensor, - outputs[0], // just to fit the collective interface + inputs, + outputs, // just to fit the collective interface [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, @@ -4430,6 +4513,10 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); return ncclSuccess; }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, OpType::GATHER, "nccl:gather"); } @@ -4500,14 +4587,17 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // inputs, which == inputTensors[0] on the root rank where it matters. bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); + const auto root = opts.rootRank; + bool nanCheck = (rank_ == root); + + auto outputs = std::vector{outputTensor}; return collective( - outputTensor, - inputs[0], // just to fit the collective interface + outputs, + inputs, // just to fit the collective interface [&](at::Tensor& /* unused */, at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams) { for (auto input : inputs) { @@ -4519,9 +4609,14 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); return ncclSuccess; }, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, + [](at::cuda::CUDAStream&, + c10::intrusive_ptr& work) {}, OpType::SCATTER, "nccl:scatter", - avoidRecordStreams); + avoidRecordStreams, + nanCheck); } c10::intrusive_ptr ProcessGroupNCCL::recvAnysource( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ba68cda9cc3ea..2bca8992af445c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -63,6 +63,13 @@ static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { static std::vector TORCH_NCCL_DUMP_ON_TIMEOUT = { "TORCH_NCCL_DUMP_ON_TIMEOUT"}; +// TODO: remove this change after a safe rollout. +// Control whether we sleep after an exception is thrown. +// This change is temporary and is used to safely remove the current sleep that +// exists after an exception is thrown. +static std::vector TORCH_NCCL_SLEEP_AFTER_EXCEPTION = { + "TORCH_NCCL_SLEEP_AFTER_EXCEPTION"}; + // Control whether Desync Debug is enabled. This variable must be set // together with TORCH_NCCL_ASYNC_ERROR_HANDLING. static std::vector TORCH_NCCL_DESYNC_DEBUG = { @@ -178,7 +185,7 @@ struct DumpPipe { getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) { return; } - TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty"); + TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_PIPE_FILE is empty"); std::string filename = c10::str(fileStem, rank, ".pipe"); TORCH_CHECK( unlink(filename.c_str()) != -1 || errno == ENOENT, @@ -449,7 +456,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { static CUDAEventCache& get(); private: - std::timed_mutex cacheMutex_; + std::mutex cacheMutex_; // NOTE: We intentionaly store raw pointers so that // we do not attempt to destroy the event objects on process exit, // because cuda may be gone. @@ -762,7 +769,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collective( @@ -773,7 +781,20 @@ class TORCH_API ProcessGroupNCCL : public Backend { PostProcess post, OpType opType, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool avoidRecordStreams = false, + bool nanCheck = true); + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr, + bool avoidRecordStreams = false, + bool nanCheck = true); template c10::intrusive_ptr collectiveCoalesced( @@ -888,7 +909,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::future& fut, const std::chrono::milliseconds& timeOutMilSec, const std::string& futDescription, - bool throwException = false); + bool throwException = false, + bool log = false); // When watchdog timeout, this function will be called and return debug info // for users. For now we only get information from retrieveDesyncReport. @@ -918,7 +940,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // ephemeralTimeoutActive_/ephemeralTimeoutInflight_. // TODO(fduwjj): We need to have an audit on all mutexes we are adding here. // And consolidate them if possible. - std::timed_mutex mtxTimeoutExtension_; + std::mutex mtxTimeoutExtension_; // The ephemeral timeout added on top of existing timeout for works issued // before first work finishes. @@ -978,7 +1000,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { inInitializationCommMap_; // Mutex to guard maps like devNCCLCommMap_. - std::timed_mutex mutex_; + std::mutex mutex_; // Heartbeat of watchdog thread. std::atomic_uint64_t heartbeat_; @@ -1039,18 +1061,18 @@ class TORCH_API ProcessGroupNCCL : public Backend { static std::atomic shouldDump_; // Mutex to Guard workMetaList_ - std::timed_mutex workMetaListMutex_; + std::mutex workMetaListMutex_; // Mutex to Guard monitorWakeUpCV_ - std::timed_mutex monitorMutex_; + std::mutex monitorMutex_; bool writeDebugInfo_ = false; // Condition Variable for watchdog thread sleep - std::condition_variable_any workMetaListCV_; + std::condition_variable workMetaListCV_; // Condition Variable for monitor thread to wake up early - std::condition_variable_any monitorWakeUpCV_; + std::condition_variable monitorWakeUpCV_; // Vector to Store WorkNCCL pointers std::list workMetaList_; @@ -1058,10 +1080,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::chrono::time_point lastWorkListUpdateTime_; // Mutex to Guard workMetaList_ - std::timed_mutex completedWorkListMutex_; + std::mutex completedWorkListMutex_; // Condition Variable for watchdog thread sleep - std::condition_variable_any completedWorkListCV_; + std::condition_variable completedWorkListCV_; std::list completedWorkList_; @@ -1125,6 +1147,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // timeout and nccl errors. bool dumpOnTimeoutOrEx_; + // Whether or not to sleep after an exception is thrown in the watchdog. + bool sleepAfterException_; + // Whether or not to enable nan check for input tensors to collectives. bool enableNanCheck_; diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 265c78f1b78cf9..3655984d452a92 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -210,6 +210,7 @@ class TORCH_PYTHON_API PythonOnCompletionHook { // Wraps a py::object hook and acquires Python GIL in dtor before // destructing the hook object. PythonOnCompletionHook(py::object hook) : hook_(std::move(hook)) {} + PythonOnCompletionHook(const PythonOnCompletionHook&) = default; ~PythonOnCompletionHook() { py::gil_scoped_acquire ag; diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 4bd3b28ec1e24c..68c5da982c2573 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -33,53 +34,6 @@ namespace c10d { namespace detail { -class timing_guard { - Counter& counter_; - typedef std::chrono::time_point - time_point; - time_point start_; - - public: - timing_guard(Counter& counter) - : counter_(counter), start_(std::chrono::high_resolution_clock::now()) {} - - ~timing_guard() { - stop(); - } - - void stop() { - if (start_ != time_point()) { - auto diff = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_) - .count(); - counter_.update(diff); - start_ = time_point(); - } - } -}; - -void Counter::update(double val) { - count_ += 1; - - auto delta = val - mean_; - mean_ += delta / count_; - - auto delta2 = val - mean_; - m2_ += delta2 * delta2; -} - -std::unordered_map Counter::observe() const { - std::unordered_map res; - res["count"] = (double)count_; - res["mean"] = mean_; - if (count_ >= 2) { - res["sample_variance"] = m2_ / (count_ - 1); - } else { - res["sample_variance"] = std::nan("1"); - } - return res; -} - // Manages the lifecycle of a server daemon. class TCPServer { public: @@ -295,28 +249,17 @@ class SendBuffer { using detail::Socket; // TCPStore class methods -TCPStore::TCPStore( - const std::string& masterAddr, - std::uint16_t masterPort, - std::optional numWorkers, - bool isServer, - const std::chrono::milliseconds& timeout, - bool waitWorkers) - : TCPStore{ - masterAddr, - TCPStoreOptions{ - masterPort, - isServer, - numWorkers ? std::optional(*numWorkers) - : std::nullopt, - waitWorkers, - timeout}} {} +// Although we still allow multi-params in ctor in Python, that behavior is +// removed from cpp and we construct the opts implicitly for users in the pybind +// of TCPStore. TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) : Store{opts.timeout}, addr_{std::move(host)}, numWorkers_{opts.numWorkers}, usingLibUv_{opts.useLibUV} { + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init); + if (opts.useLibUV) { TORCH_CHECK( ::c10d::detail::is_libuv_tcpstore_backend_available(), @@ -421,7 +364,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) TCPStore::~TCPStore() = default; void TCPStore::waitForWorkers() { - detail::timing_guard tguard(clientCounters_["waitForWorkers"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers); if (numWorkers_ == std::nullopt) { return; } @@ -491,7 +434,7 @@ void TCPStore::_splitSet( } void TCPStore::set(const std::string& key, const std::vector& data) { - detail::timing_guard tguard(clientCounters_["set"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::SET); buffer.appendString(keyPrefix_ + key); @@ -503,7 +446,7 @@ std::vector TCPStore::compareSet( const std::string& key, const std::vector& expectedValue, const std::vector& desiredValue) { - detail::timing_guard tguard(clientCounters_["compareSet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET); buffer.appendString(keyPrefix_ + key); @@ -515,7 +458,7 @@ std::vector TCPStore::compareSet( } std::vector TCPStore::get(const std::string& key) { - detail::timing_guard tguard(clientCounters_["get"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get); const std::lock_guard lock(activeOpLock_); return doGet(keyPrefix_ + key); } @@ -530,13 +473,13 @@ std::vector TCPStore::doGet(const std::string& key) { } int64_t TCPStore::add(const std::string& key, int64_t value) { - detail::timing_guard tguard(clientCounters_["add"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add); const std::lock_guard lock(activeOpLock_); return incrementValueBy(keyPrefix_ + key, value); } bool TCPStore::deleteKey(const std::string& key) { - detail::timing_guard tguard(clientCounters_["deleteKey"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY); buffer.appendString(keyPrefix_ + key); @@ -564,7 +507,7 @@ int64_t TCPStore::getNumKeys() { } bool TCPStore::check(const std::vector& keys) { - detail::timing_guard tguard(clientCounters_["check"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::CHECK); buffer.appendValue(keys.size()); @@ -591,7 +534,7 @@ void TCPStore::wait(const std::vector& keys) { void TCPStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { - detail::timing_guard tguard(clientCounters_["wait"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait); const std::lock_guard lock(activeOpLock_); std::vector prefixedKeys{}; prefixedKeys.reserve(keys.size()); @@ -652,7 +595,7 @@ void TCPStore::doWait( void TCPStore::append( const std::string& key, const std::vector& data) { - detail::timing_guard tguard(clientCounters_["append"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append); const std::lock_guard lock(activeOpLock_); detail::SendBuffer buffer(*client_, detail::QueryType::APPEND); buffer.appendString(keyPrefix_ + key); @@ -662,7 +605,7 @@ void TCPStore::append( std::vector> TCPStore::multiGet( const std::vector& keys) { - detail::timing_guard tguard(clientCounters_["multiGet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet); const std::lock_guard lock(activeOpLock_); std::vector prefixedKeys; prefixedKeys.reserve(keys.size()); @@ -689,7 +632,7 @@ std::vector> TCPStore::multiGet( void TCPStore::multiSet( const std::vector& keys, const std::vector>& values) { - detail::timing_guard tguard(clientCounters_["multiSet"]); + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet); TORCH_CHECK( keys.size() == values.size(), "multiSet keys and values vectors must be of same size"); @@ -708,15 +651,6 @@ bool TCPStore::hasExtendedApi() const { return true; } -std::unordered_map> -TCPStore::collectClientCounters() const noexcept { - std::unordered_map> res; - for (const auto& kv : clientCounters_) { - res[kv.first] = kv.second.observe(); - } - return res; -} - std::string TCPStore::repr() const { auto clientRepr = client_ ? client_->repr() : ""; auto serverRepr = server_ ? server_->repr() : ""; diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 015d134e983fe3..37862fc25fac33 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -9,6 +9,33 @@ namespace c10d { namespace detail { +// TCPStore is a key-value store used by PyTorch mainly for distributed +// rendezvous, but for other purposes as well. (e.g., a centralized storage for +// synchronization among different processes.) +// +// It is run via a classic client-server architecture, where the server runs +// a separate background thread (alternatively we call it daemon thread). The +// client and server communicate via TCP sockets. +// +// Currently we have two types of server backends: +// 1. TCPStoreBackend: a single thread to handle all incoming request +// synchronously. +// 2. LibUVTCPStoreBackend: an event-driven asynchronous stream processing that +// leverages libuv library (https://github.com/libuv/libuv) for better +// performance. And this backend now is recommended to users. (We set the +// default value of `useLibUV` inside `TCPStoreOptions` to true now, so users +// should get it by default). +// +// Code structure: +// ├── TCPStore client side API and server setup code: +// │ TCPStore.hpp/TCPStore.cpp +// ├── TCPStoreBackend server side API implementation code: +// │ TCPStoreBackend.hpp/TCPStoreBackend.cpp +// | (actual class:`TCPStoreMasterDaemon`) +// ├── LibUVTCPStoreBackend +// │ TCPStoreLibUvBackend.cpp +// | (actual class: `LibUVStoreDaemon`) + class TCPServer; class TCPClient; @@ -18,30 +45,6 @@ struct SocketAddress { std::uint16_t port{}; }; -class Counter { - public: - void update(double val); - std::unordered_map observe() const; - - double mean() const noexcept { - return mean_; - } - int64_t count() const noexcept { - return count_; - } - double variance() const noexcept { - return m2_ / static_cast(count_); - } - double sample_variance() const noexcept { - return m2_ / static_cast(count_ - 1); - } - - private: - int64_t count_ = 0; - double mean_ = 0; - double m2_ = 0; -}; - } // namespace detail struct TCPStoreOptions { @@ -72,14 +75,6 @@ class TORCH_API TCPStore : public Store { explicit TCPStore(std::string host, const TCPStoreOptions& opts = {}); - [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( - const std::string& masterAddr, - std::uint16_t masterPort, - std::optional numWorkers = std::nullopt, - bool isServer = false, - const std::chrono::milliseconds& timeout = kDefaultTimeout, - bool waitWorkers = true); - ~TCPStore() override; void set(const std::string& key, const std::vector& value) override; @@ -130,9 +125,6 @@ class TORCH_API TCPStore : public Store { return addr_.port; } - std::unordered_map> - collectClientCounters() const noexcept; - bool isLibUvBackend() const noexcept { return usingLibUv_; } @@ -162,7 +154,6 @@ class TORCH_API TCPStore : public Store { const std::string initKey_ = "init/"; const std::string keyPrefix_ = "/"; std::mutex activeOpLock_; - std::unordered_map clientCounters_; bool usingLibUv_ = true; }; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 0cac94bcd13dba..c3fa09ab38bef1 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef TORCH_USE_LIBUV #include @@ -89,6 +90,7 @@ class UvHandle : public c10::intrusive_ptr_target { class UvTcpSocket : public UvHandle { uv_tcp_t client{}; + std::string address_{"unknown"}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); @@ -100,7 +102,6 @@ class UvTcpSocket : public UvHandle { } static void alloc_buffer( - uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { @@ -144,7 +145,24 @@ class UvTcpSocket : public UvHandle { } } + const std::string& address() const { + return address_; + } + void startRead() { + struct ::sockaddr_storage addr {}; + int addrLen{sizeof(struct ::sockaddr_storage)}; + + if (int err = uv_tcp_getpeername( + &client, reinterpret_cast(&addr), &addrLen)) { + C10D_WARNING( + "The remote address of the client socket cannot be retrieved. err={}", + uv_strerror(err)); + } else { + address_ = + formatSockAddr(reinterpret_cast(&addr), addrLen); + } + int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback); if (res) { C10D_WARNING( @@ -185,7 +203,7 @@ class UvTcpServer : public UvTcpSocket { public: typedef std::function OnConnectCallback; explicit UvTcpServer(uv_loop_t* loop) - : UvTcpSocket(loop), onConnectCb(missingOnConnect) {} + : UvTcpSocket(loop), onConnectCb_(missingOnConnect) {} static c10::intrusive_ptr makeWithSocket( uv_loop_t* loop, @@ -216,7 +234,7 @@ class UvTcpServer : public UvTcpSocket { } void setOnConnectCallback(OnConnectCallback&& callback) { - onConnectCb = std::move(callback); + onConnectCb_ = std::move(callback); } static c10::intrusive_ptr makeWithPort( @@ -247,8 +265,8 @@ class UvTcpServer : public UvTcpSocket { ", message: ", uv_strerror(uv_res)); - uv_res = - uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0); + uv_res = uv_tcp_bind( + res->unsafeGetSocket(), (const struct ::sockaddr*)&addr, 0); TORCH_CHECK( uv_res == 0, "The server socket has failed to bind. ", @@ -289,7 +307,7 @@ class UvTcpServer : public UvTcpSocket { } uint16_t port() const { - return portNum; + return portNum_; } void accept(const c10::intrusive_ptr& socket) { @@ -307,8 +325,8 @@ class UvTcpServer : public UvTcpSocket { } private: - OnConnectCallback onConnectCb; - uint16_t portNum{}; + OnConnectCallback onConnectCb_; + uint16_t portNum_{}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); @@ -326,16 +344,16 @@ class UvTcpServer : public UvTcpSocket { if (uv_tcp_getsockname( (uv_tcp_t*)unsafeGetStream(), - reinterpret_cast(&addr_s), + reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != 0) { throw std::runtime_error( "The port number of the socket cannot be retrieved."); } if (addr_s.ss_family == AF_INET) { - portNum = ntohs(reinterpret_cast(&addr_s)->sin_port); + portNum_ = ntohs(reinterpret_cast(&addr_s)->sin_port); } else { - portNum = ntohs(reinterpret_cast(&addr_s)->sin6_port); + portNum_ = ntohs(reinterpret_cast(&addr_s)->sin6_port); } } @@ -344,7 +362,7 @@ class UvTcpServer : public UvTcpSocket { } static void on_new_connection(uv_stream_t* server, int status) { - borrow(server)->onConnectCb(status); + borrow(server)->onConnectCb_(status); } }; @@ -628,10 +646,10 @@ class LibUVStoreDaemon : public BackgroundThread { void stop() override; private: - uv_loop_t loop{}; - c10::intrusive_ptr tcpServer; + uv_loop_t loop_{}; + c10::intrusive_ptr tcpServer_; - uv_async_t exit_handle{}; + uv_async_t exit_handle_{}; std::unordered_map> tcpStore_; // From key -> the list of UvClient waiting on the key std::unordered_map>> @@ -754,6 +772,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_value(validateNumber)) return false; + C10D_TRACE("validate magic:{} address:{}", validateNumber, this->address()); + if (validateNumber != c10d::detail::validationMagicNumber) return false; return true; @@ -765,6 +785,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("ping nonce:{} address:{}", nonce, this->address()); + StreamWriter sw(iptr()); sw.write_value(nonce); sw.send(); @@ -780,6 +802,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_payload(newData)) return false; + C10D_TRACE("set key:{} address:{}", key, this->address()); + store->set(key, newData); return true; } @@ -797,6 +821,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_payload(newValue)) return false; + C10D_TRACE("compareAndSet key:{} address:{}", key, this->address()); + auto res = store->compareAndSet(key, currentValue, newValue); StreamWriter sw(iptr()); sw.write_vector(res); @@ -810,6 +836,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; + C10D_TRACE("get key:{} address:{}", key, this->address()); + const auto& data = store->get(key); StreamWriter sw(iptr()); sw.write_vector(data); @@ -826,6 +854,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_value(addVal)) return false; + C10D_TRACE("add key:{} val:{} address:{}", key, addVal, this->address()); + addVal = store->add(key, addVal); StreamWriter sw(iptr()); sw.write_value(addVal); @@ -852,6 +882,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("check key_count:{} address:{}", key_count, this->address()); + // Now we have received all the keys StreamWriter sw(iptr()); if (store->checkKeys(keys)) { @@ -882,6 +914,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("wait key_count:{} address:{}", key_count, this->address()); + if (store->waitKeys(keys, iptr())) { StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::STOP_WAITING); @@ -892,6 +926,8 @@ class UvClient : public UvTcpSocket { } bool parse_getnumkeys_command() { + C10D_TRACE("getnumkeys address:{}", this->address()); + StreamWriter sw(iptr()); sw.write_value(store->size()); sw.send(); @@ -904,6 +940,8 @@ class UvClient : public UvTcpSocket { if (!stream.read_key(key)) return false; + C10D_TRACE("delete key:{} address:{}", key, this->address()); + auto numDeleted = store->deleteKey(key); StreamWriter sw(iptr()); sw.write_value(numDeleted); @@ -923,6 +961,8 @@ class UvClient : public UvTcpSocket { return false; } + C10D_TRACE("append key:{} address:{}", key, this->address()); + store->append(key, data); return true; } @@ -940,6 +980,8 @@ class UvClient : public UvTcpSocket { ", max: ", MAX_KEY_COUNT); + C10D_TRACE("multi_get key_count:{} address:{}", key_count, this->address()); + StreamWriter sw(iptr()); for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning @@ -968,6 +1010,8 @@ class UvClient : public UvTcpSocket { ", max: ", MAX_KEY_COUNT); + C10D_TRACE("multi_set key_count:{} address:{}", key_count, this->address()); + for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning @@ -988,6 +1032,8 @@ class UvClient : public UvTcpSocket { bool parse_cancel_wait_command() { store->clearClientWaitState(iptr()); + C10D_TRACE("cancel_wait key_count:{} address:{}", this->address()); + StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED); sw.send(); @@ -1018,10 +1064,10 @@ class UvClient : public UvTcpSocket { }; void LibUVStoreDaemon::onConnect(int status) { - auto client = UvClient::make(&loop, this); + auto client = UvClient::make(&loop_, this); registerClient(client); try { - tcpServer->accept(client); + tcpServer_->accept(client); client->startRead(); } catch (std::exception& e) { C10D_WARNING("Failed to accept client due to {}", e.what()); @@ -1031,27 +1077,28 @@ void LibUVStoreDaemon::onConnect(int status) { void LibUVStoreDaemon::onExitRequest() { C10D_DEBUG("Store exit requested\n"); - uv_close((uv_handle_t*)&exit_handle, nullptr); - uv_stop(&loop); + uv_close((uv_handle_t*)&exit_handle_, nullptr); + uv_stop(&loop_); } void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { if (opts.masterListenFd.has_value()) { - tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd); + tcpServer_ = UvTcpServer::makeWithSocket(&loop_, *opts.masterListenFd); } else { try { - tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true); + tcpServer_ = + UvTcpServer::makeWithPort(&loop_, opts.port, /*useIpv6=*/true); } catch (std::exception& ex) { C10D_INFO( "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what()); - tcpServer = - UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false); + tcpServer_ = + UvTcpServer::makeWithPort(&loop_, opts.port, /*useIpv6=*/false); } } - tcpServer->setOnConnectCallback( + tcpServer_->setOnConnectCallback( [this](auto status) { this->onConnect(status); }); - port_ = tcpServer->port(); + port_ = tcpServer_->port(); TORCH_CHECK( port_ == opts.port || opts.port == 0, // zero means use any port "listen fd ", @@ -1063,19 +1110,19 @@ void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { } LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) { - TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop"); + TORCH_CHECK(uv_loop_init(&loop_) == 0, "Failed to init uv loop"); TORCH_CHECK( - uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) == + uv_async_init(&loop_, &exit_handle_, LibUVStoreDaemon::on_exit_request) == 0, "Failed to init uv async event"); - uv_handle_set_data((uv_handle_t*)&exit_handle, this); + uv_handle_set_data((uv_handle_t*)&exit_handle_, this); } LibUVStoreDaemon::~LibUVStoreDaemon() { if (!is_running()) { - uv_close((uv_handle_t*)&exit_handle, nullptr); - uv_run(&loop, UV_RUN_NOWAIT); - TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work"); + uv_close((uv_handle_t*)&exit_handle_, nullptr); + uv_run(&loop_, UV_RUN_NOWAIT); + TORCH_CHECK(uv_loop_close(&loop_) == 0, "loop cleanup didn't work"); } else { // the daemon thread cleanup libuv dispose(); @@ -1098,7 +1145,7 @@ void LibUVStoreDaemon::run() { c10::setThreadName("pt_tcpstore_uv"); C10D_DEBUG("Uv main loop running"); - int res = uv_run(&loop, UV_RUN_DEFAULT); + int res = uv_run(&loop_, UV_RUN_DEFAULT); if (res) { C10D_DEBUG("UV main loop done: res:{}", res); } @@ -1107,21 +1154,21 @@ void LibUVStoreDaemon::run() { if (debug_enabled) { C10D_DEBUG("Walking live handles prior to closing clients"); - uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); + uv_walk(&loop_, LibUVStoreDaemon::print_active_handles, nullptr); } for (const auto& client : clients_) { client->close(); } - tcpServer->close(); + tcpServer_->close(); if (debug_enabled) { C10D_DEBUG("Walking live handles after closing clients"); - uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); + uv_walk(&loop_, LibUVStoreDaemon::print_active_handles, nullptr); } while (true) { - res = uv_loop_close(&loop); + res = uv_loop_close(&loop_); if (res == 0) { break; } @@ -1130,7 +1177,7 @@ void LibUVStoreDaemon::run() { res, uv_err_name(res), uv_strerror(res)); - res = uv_run(&loop, UV_RUN_NOWAIT); + res = uv_run(&loop_, UV_RUN_NOWAIT); if (res != 0) { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } @@ -1139,7 +1186,7 @@ void LibUVStoreDaemon::run() { } void LibUVStoreDaemon::stop() { - int res = uv_async_send(&exit_handle); + int res = uv_async_send(&exit_handle_); if (res) { C10D_WARNING( "uv_async_send failed with:{} errn:{} desc:{}\n", diff --git a/torch/csrc/distributed/c10d/Utils.cu b/torch/csrc/distributed/c10d/Utils.cu deleted file mode 100644 index ae2017efdf8d97..00000000000000 --- a/torch/csrc/distributed/c10d/Utils.cu +++ /dev/null @@ -1,49 +0,0 @@ -#include -#include -#include -#include -#include -#include - -namespace c10d { - -// CUDA kernel to check if data has NAN, device side assert -// is raised if NAN is found -template -__global__ void checkForNaN(T* data, size_t size) { - size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - size_t stride = blockDim.x * gridDim.x; - - for (size_t i = tid; i < size; i += stride) { - CUDA_KERNEL_ASSERT(!isnan(data[i])); - } -} - -// CHECK if a Tensor contains NAN in any of its element -void checkForNan(const at::Tensor& tensor) { - // skip check for non float types - if (!torch::is_floating_point(tensor)) { - return; - } - const size_t maxNumThreadsPerBlock = 256; - const size_t maxNumBlocks = 24; - const size_t numThreadsPerBlock = - std::min(maxNumThreadsPerBlock, tensor.numel()); - - const size_t numBlocks = std::min( - maxNumBlocks, - (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock); - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - tensor.scalar_type(), - "checkForNaN", - [&] { - checkForNaN<<>>( - tensor.data_ptr(), tensor.numel()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 5c7365393404fc..ea4a4653bc35fc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -611,8 +611,6 @@ using SizeType = uint64_t; // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) -void checkForNan(const at::Tensor& tensor); - namespace tcputil { // Send and receive diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index 6a45392bc86231..8beb8f29362080 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -1,7 +1,6 @@ #include #include -#include #include namespace c10d { @@ -46,17 +45,17 @@ OpType Work::retrieveOpType() const { Work::~Work() = default; bool Work::isCompleted() { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return completed_; } bool Work::isSuccess() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return !exception_; } std::exception_ptr Work::exception() const { - C10D_LOCK_GUARD(lock, mutex_); + std::lock_guard lock(mutex_); return exception_; } @@ -74,7 +73,7 @@ std::vector Work::result() { void Work::synchronize() {} bool Work::wait(std::chrono::milliseconds timeout) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); if (timeout == kNoTimeout) { // This waits without a timeout. cv_.wait(lock, [&] { return completed_; }); @@ -104,7 +103,7 @@ c10::intrusive_ptr Work::getFuture() { } void Work::finish(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { @@ -116,7 +115,7 @@ void Work::finish(std::exception_ptr exception) { } void Work::finishAndThrow(std::exception_ptr exception) { - C10D_LOCK_GUARD(lock, mutex_); + std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index ea5853aefb4049..c10e5007b9f544 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -126,8 +126,8 @@ class TORCH_API Work : public torch::CustomClassHolder { // provided by the user. void finishAndThrow(std::exception_ptr exception); - mutable std::timed_mutex mutex_; - std::condition_variable_any cv_; + mutable std::mutex mutex_; + std::condition_variable cv_; bool completed_ = false; std::exception_ptr exception_; diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index c4d6b2f61fcbbc..047459b965589a 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -8,6 +7,15 @@ #include #include +// NS: TODO: Use `std::filesystem` regardless of OS when it's possible +// to use it without leaking symbols on PRECXX11 ABI Linux OSes +// See https://github.com/pytorch/pytorch/issues/133437 for more details +#ifdef _WIN32 +#include +#else +#include +#endif + namespace c10d::control_plane { namespace { @@ -70,6 +78,15 @@ std::string jsonStrEscape(const std::string& str) { } return ostream.str(); } + +bool file_exists(const std::string& path) { +#ifdef _WIN32 + return std::filesystem::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} } // namespace WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { @@ -144,7 +161,7 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { // using unix sockets server_.set_address_family(AF_UNIX); - if (std::filesystem::exists(hostOrFile)) { + if (file_exists(hostOrFile)) { throw std::runtime_error(fmt::format("{} already exists", hostOrFile)); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index e3ed6d6bd4bcb4..0f7792e64e5faa 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1556,15 +1556,10 @@ Example:: py::arg("master_listen_fd") = py::none(), py::arg("use_libuv") = true, py::call_guard()) - .def( - "collect_client_counters", - &::c10d::TCPStore::collectClientCounters, - "Return a dict of counters for tcp store client") .def_property_readonly( "host", &::c10d::TCPStore::getHost, R"(Gets the hostname on which the store listens for requests.)") - .def_property_readonly( "port", &::c10d::TCPStore::getPort, @@ -1818,8 +1813,7 @@ communication mechanism. py::init< const c10::intrusive_ptr<::c10d::Store>&, int, - int, - c10::intrusive_ptr<::c10d::ProcessGroup::Options>>(), + int>(), py::call_guard()) .def("rank", &::c10d::ProcessGroup::getRank) .def("size", &::c10d::ProcessGroup::getSize) @@ -1829,7 +1823,6 @@ communication mechanism. "_backend_id", &::c10d::ProcessGroup::getBackendID, py::arg("backend_type")) - .def_property_readonly("options", &::c10d::ProcessGroup::getOptions) .def( "broadcast", &::c10d::ProcessGroup::broadcast, @@ -2139,6 +2132,14 @@ communication mechanism. }, py::arg("device"), py::call_guard()) + .def( + "_set_default_backend", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const ::c10d::ProcessGroup::BackendType& backendType) { + return self->setDefaultBackend(backendType); + }, + py::arg("backend_type"), + py::call_guard()) .def( "_register_on_completion_hook", [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, @@ -2242,27 +2243,6 @@ The hook must have the following signature: .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) .export_values(); - // base ProcessGroup::Options binding - auto processGroupOptions = - intrusive_ptr_class_<::c10d::ProcessGroup::Options>( - processGroup, - "Options", - R"( -Base class for all processes group options implementations, such as the nccl -options :class:`~torch.distributed.ProcessGroupNCCL.Options`). -)") - .def( - py::init([](const std::string& backend, - const std::chrono::milliseconds& timeout) { - return c10::make_intrusive<::c10d::ProcessGroup::Options>( - backend, timeout); - }), - py::arg("backend"), - py::arg("timeout") = kProcessGroupDefaultTimeout, - py::call_guard()) - .def_readonly("backend", &::c10d::ProcessGroup::Options::backend) - .def_readwrite("_timeout", &::c10d::ProcessGroup::Options::timeout); - // TODO: The collection definitions handles direct instantiation of // ProcessGroup subclasses (e.g. dist.ProcessGroupGloo). This is not supported // and should be removed once all tests are transitioned @@ -2547,7 +2527,8 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::call_guard()) .def( "eager_connect_single_device", - &::c10d::Backend::eagerConnectSingleDevice) + &::c10d::Backend::eagerConnectSingleDevice, + py::call_guard()) .def( "_get_backend_name", &::c10d::Backend::getBackendName, @@ -2561,6 +2542,29 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). &::c10d::Backend::endCoalescing, py::call_guard()); + // base Backend::Options binding + // TODO: Maybe we can consider how to merge this with + // `DistributedBackendOptions`. + auto backendOptions = + intrusive_ptr_class_<::c10d::Backend::Options>( + backend, + "Options", + R"( +Base class for all backend options implementations, such as the nccl +options :class:`~torch.distributed.ProcessGroupNCCL.Options`). +)") + .def( + py::init([](const std::string& backend, + const std::chrono::milliseconds& timeout) { + return c10::make_intrusive<::c10d::Backend::Options>( + backend, timeout); + }), + py::arg("backend"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::call_guard()) + .def_readonly("backend", &::c10d::Backend::Options::backend) + .def_readwrite("_timeout", &::c10d::Backend::Options::timeout); + #ifdef USE_C10D_GLOO static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; @@ -2572,7 +2576,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). shared_ptr_class_<::gloo::transport::Device>(processGroupGloo, "Device"); intrusive_ptr_class_<::c10d::ProcessGroupGloo::Options>( - processGroupGloo, "_Options", processGroupOptions) + processGroupGloo, "_Options", backendOptions) .def(py::init<>()) .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices) .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads); @@ -2799,7 +2803,7 @@ for details. intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( processGroupNCCL, "Options", - processGroupOptions, + backendOptions, R"( ProcessGroup options for the NCCL backend diff --git a/torch/csrc/distributed/c10d/logging.cpp b/torch/csrc/distributed/c10d/logging.cpp index 68087312fd68cf..5d05b5a3a5a88d 100644 --- a/torch/csrc/distributed/c10d/logging.cpp +++ b/torch/csrc/distributed/c10d/logging.cpp @@ -34,20 +34,4 @@ bool isLogLevelEnabled(LogLevel level) noexcept { return false; } -void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - c10::string_view desc, - c10::string_view file, - int line) { - while (!lock.try_lock_for(log_interval)) { - C10D_WARNING( - "{}:{} {}: waiting for lock for {}ms", - file, - line, - desc, - log_interval.count()); - } -} - } // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/logging.h b/torch/csrc/distributed/c10d/logging.h index 38679a3ac7d42e..a7cc82f702eea2 100644 --- a/torch/csrc/distributed/c10d/logging.h +++ b/torch/csrc/distributed/c10d/logging.h @@ -6,7 +6,6 @@ #pragma once -#include #include #include @@ -25,44 +24,25 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) { return fmt::vformat(fmt, fmt::make_format_args(args...)); } -// logWithLogging is a wrapper around std::unique_lock -// that automatically logs if the lock cannot be acquired within a given -// timeout. -TORCH_API void lockWithLogging( - std::unique_lock& lock, - std::chrono::milliseconds log_interval, - c10::string_view desc, - c10::string_view file, - int line); - } // namespace detail } // namespace c10d -#define C10D_ERROR(...) \ - LOG_IF( \ - ERROR, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_ERROR(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ + LOG(ERROR) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) #define C10D_WARNING(...) \ - LOG_IF( \ - WARNING, \ - c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) - -#define C10D_INFO(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info)) \ - << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Warning)) \ + LOG(WARNING) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_DEBUG(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug)) \ - << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_INFO(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Info)) \ + LOG(INFO) << "[c10d] " << c10d::detail::formatLogMessage(__VA_ARGS__) -#define C10D_TRACE(...) \ - LOG_IF(INFO, c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ - << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) +#define C10D_DEBUG(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug)) \ + LOG(INFO) << "[c10d - debug] " << c10d::detail::formatLogMessage(__VA_ARGS__) -// TODO: use std::source_location() when we can use C++20 -#define C10D_LOCK_GUARD(name, mutex) \ - std::unique_lock name{mutex, std::defer_lock}; \ - ::c10d::detail::lockWithLogging( \ - name, std::chrono::seconds(30), #mutex, __FILE__, __LINE__) +#define C10D_TRACE(...) \ + if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Trace)) \ + LOG(INFO) << "[c10d - trace] " << c10d::detail::formatLogMessage(__VA_ARGS__) diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index f155f825284236..db4519d7b2ad38 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -37,6 +37,7 @@ C10_DIAGNOSTIC_POP() #include #include #include +#include #include @@ -193,6 +194,43 @@ class SocketImpl { Handle hnd_; const std::optional remote_; }; + +std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { + char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT + + if (int err = ::getnameinfo( + addr, len, host, NI_MAXHOST, port, NI_MAXSERV, NI_NUMERICSERV)) { + C10D_WARNING( + "The hostname of the client socket cannot be retrieved. err={}", err); + + // if we can't resolve the hostname, display the IP address + if (addr->sa_family == AF_INET) { + struct sockaddr_in* psai = (struct sockaddr_in*)&addr; + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != + NULL) { + return fmt::format("{}:{}", ip, psai->sin_port); + } + } else if (addr->sa_family == AF_INET6) { + struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; + char ip[INET6_ADDRSTRLEN]; + if (inet_ntop( + addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != + NULL) { + return fmt::format("[{}]:{}", ip, psai->sin6_port); + } + } + + C10_THROW_ERROR( + DistNetworkError, + fmt::format( + "failed to format addr, unknown family={}", addr->sa_family)); + } + if (addr->sa_family == AF_INET) { + return fmt::format("{}:{}", host, port); + } + return fmt::format("[{}]:{}", host, port); +} } // namespace c10d::detail // @@ -208,45 +246,10 @@ struct formatter<::addrinfo> { template decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) const { - char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT - - int r = ::getnameinfo( - addr.ai_addr, - addr.ai_addrlen, - host, - NI_MAXHOST, - port, - NI_MAXSERV, - NI_NUMERICSERV); - if (r != 0) { - // if we can't resolve the hostname, display the IP address - if (addr.ai_family == AF_INET) { - struct sockaddr_in* psai = (struct sockaddr_in*)addr.ai_addr; - char ip[INET_ADDRSTRLEN]; - if (inet_ntop(addr.ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != - NULL) { - return fmt::format_to(ctx.out(), "{}:{}", ip, psai->sin_port); - } - } else if (addr.ai_family == AF_INET6) { - struct sockaddr_in6* psai = (struct sockaddr_in6*)addr.ai_addr; - char ip[INET6_ADDRSTRLEN]; - if (inet_ntop( - addr.ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != - NULL) { - return fmt::format_to(ctx.out(), "[{}]:{}", ip, psai->sin6_port); - } - } - C10_THROW_ERROR( - DistNetworkError, - fmt::format( - "failed to format addr, unknown family={}", addr.ai_family)); - } - - if (addr.ai_addr->sa_family == AF_INET) { - return fmt::format_to(ctx.out(), "{}:{}", host, port); - } else { - return fmt::format_to(ctx.out(), "[{}]:{}", host, port); - } + return fmt::format_to( + ctx.out(), + "{}", + c10d::detail::formatSockAddr(addr.ai_addr, addr.ai_addrlen)); } }; diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h index 1e42a53b0b530d..de9bd6989c290e 100644 --- a/torch/csrc/distributed/c10d/socket.h +++ b/torch/csrc/distributed/c10d/socket.h @@ -103,7 +103,5 @@ class Socket { std::unique_ptr impl_; }; - } // namespace detail - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/socket_fmt.h b/torch/csrc/distributed/c10d/socket_fmt.h new file mode 100644 index 00000000000000..8c7832ebf933cd --- /dev/null +++ b/torch/csrc/distributed/c10d/socket_fmt.h @@ -0,0 +1,32 @@ +// (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +/* +This file should not be included from other .h files and only used in cpp files +as it exposes the underlying platform specific socket headers. +*/ + +#include + +#ifdef _WIN32 +#include + +#include +#include +#else +#include +#endif + +namespace c10d { +namespace detail { + +// Returns a human-readable representation of the given socket address. +std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len); + +} // namespace detail +} // namespace c10d diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index e65b503e574f16..e93be25483051c 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -477,7 +477,7 @@ void writeWrappedPayload( int64_t indexToWrite = originalPayload.size(); originalPayload.resize(originalPayload.size() + sizeof(int64_t)); const int64_t additionalPayloadSize = additionalPayload.size(); - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( reinterpret_cast(originalPayload.data()) + indexToWrite, &additionalPayloadSize, torch::utils::THPByteOrder::THP_BIG_ENDIAN, @@ -492,7 +492,7 @@ std::vector readWrappedPayload( int64_t additionalPayloadSize; TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t)); size_t indexToRead = payload.size() - sizeof(int64_t); - torch::utils::THP_decodeInt64Buffer( + torch::utils::THP_decodeBuffer( &additionalPayloadSize, reinterpret_cast(payload.data()) + indexToRead, torch::utils::THPByteOrder::THP_BIG_ENDIAN, diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index 8a5ab17e57a490..ea61825d614e86 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -4,11 +4,11 @@ #include #include -CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) { +CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) + : backend{backend} { this->check_fn = guarded_code.attr("check_fn"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); - this->backend = backend; // TODO - clean this up when enable_cpp_guard_manager is True by default if (py::hasattr(this->check_fn, "root")) { this->root_mgr = torch::dynamo::convert_to_root_guard_manager( @@ -16,11 +16,17 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) { } } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( + "-Wdeprecated-copy-with-user-provided-dtor") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") +// NOLINTNEXTLINE(bugprone-exception-escape) CacheEntry::~CacheEntry() { // prevent check_fn from use-after-free when invalidating this->check_fn.attr("cache_entry") = py::none(); this->check_fn.attr("extra_state") = py::none(); } +C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() py::object CacheEntry::next() { NULL_CHECK(this->_owner); @@ -48,8 +54,5 @@ PyObject* get_backend(PyObject* callback) { while (py::hasattr(handle, "_torchdynamo_orig_callable")) { handle = handle.attr("_torchdynamo_orig_callable"); } - if (py::hasattr(handle, "compiler_fn")) { - handle = handle.attr("compiler_fn"); - } return handle.ptr(); } diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h index af909dc749b444..3d2391d23f8475 100644 --- a/torch/csrc/dynamo/cache_entry.h +++ b/torch/csrc/dynamo/cache_entry.h @@ -36,6 +36,9 @@ typedef struct ExtraState ExtraState; #ifdef __cplusplus +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( + "-Wdeprecated-copy-with-user-provided-dtor") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") typedef struct VISIBILITY_HIDDEN CacheEntry { // check the guards: lambda: : bool py::object check_fn; @@ -58,6 +61,8 @@ typedef struct VISIBILITY_HIDDEN CacheEntry { // Warning: returns a reference whose lifetime is controlled by C++ py::object next(); } CacheEntry; +C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() #endif diff --git a/torch/csrc/dynamo/cpp_shim.cpp b/torch/csrc/dynamo/cpp_shim.cpp index 35c415fe57425a..84c6e0baaf8d95 100644 --- a/torch/csrc/dynamo/cpp_shim.cpp +++ b/torch/csrc/dynamo/cpp_shim.cpp @@ -1,6 +1,5 @@ -#include - #include +#include struct _PytorchRecordFunctionState { at::RecordFunction guard; @@ -14,6 +13,24 @@ _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) { return state; } +static inline _PytorchRecordFunctionState* +_pytorch_record_function_enter_with_kwinputs( + const char* name, + const std::unordered_map* kwargs) { + _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState(); + std::vector args; + state->guard.before(name, &args, kwargs); + return state; +} + +_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context( + const char* name, + const char* context) { + auto map = std::unordered_map(); + map.insert({"context", c10::IValue(context)}); + return _pytorch_record_function_enter_with_kwinputs(name, &map); +} + void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) { if (state == nullptr) { return; diff --git a/torch/csrc/dynamo/cpp_shim.h b/torch/csrc/dynamo/cpp_shim.h index 5baf67805b06c3..b5ec73a3bbfaaa 100644 --- a/torch/csrc/dynamo/cpp_shim.h +++ b/torch/csrc/dynamo/cpp_shim.h @@ -1,5 +1,4 @@ #pragma once - #ifdef __cplusplus extern "C" { #endif @@ -8,6 +7,9 @@ struct _PytorchRecordFunctionState; typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState; _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name); +_PytorchRecordFunctionState* _pytorch_record_function_enter_with_context( + const char* name, + const char* context); void _pytorch_record_function_exit(_PytorchRecordFunctionState* state); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 35032f07770844..181acbdf4946a6 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -10,9 +10,11 @@ #include #include +#define MAX_COMPILE_CONTEXT_SIZE 100 + PyObject* guard_error_hook = NULL; const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; - +static char compile_context[MAX_COMPILE_CONTEXT_SIZE]; static int active_dynamo_threads = 0; static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT; @@ -134,7 +136,7 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = { static PyTypeObject THPPyInterpreterFrameType = { PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame", + .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame", .tp_basicsize = sizeof(THPPyInterpreterFrame), .tp_flags = Py_TPFLAGS_DEFAULT, .tp_getset = THPPyInterpreterFrame_properties, @@ -483,7 +485,8 @@ inline static PyObject* eval_custom_code( PyCodeObject* code, int throw_flag, int free_vars_copied) { - _PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region"); + const char* trace_id = compile_context; + _PytorchRecordFunctionState* rf = _pytorch_record_function_enter_with_context("Torch-Compiled Region", trace_id); PyObject* result = eval_custom_code_impl( tstate, frame, @@ -518,6 +521,8 @@ static PyObject* _custom_eval_frame_shim( return result; } +static PyObject* skip_code_recursive_flag; + // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping // the frame, meaning that unless we default evaluate the original frame, // we are responsible for clearing it - via clear_old_frame_if_python_312_plus. @@ -577,6 +582,13 @@ static PyObject* _custom_eval_frame( DEBUG_TRACE("skip %s", get_frame_name(frame)); return eval_frame_default(tstate, frame, throw_flag); } + if (extra == SKIP_CODE_RECURSIVE) { + DEBUG_TRACE("skip recursive %s", get_frame_name(frame)); + eval_frame_callback_set(Py_None); + PyObject* result = eval_frame_default(tstate, frame, throw_flag); + eval_frame_callback_set(callback); + return result; + } if (extra == NULL) { extra = init_and_set_extra_state(F_CODE(frame)); @@ -665,6 +677,14 @@ static PyObject* _custom_eval_frame( // inside the torch.compile block we won't try to Dynamo anything else. *should_clear_frame = 1; return NULL; + } else if (result == skip_code_recursive_flag) { + // Dynamo returned skip_code_recursive_flag, so we should recursively skip code. + DEBUG_TRACE("create skip recursive %s", get_frame_name(frame)); + set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE); + PyObject* r = eval_frame_default(tstate, frame, throw_flag); + // Re-enable custom behavior + eval_frame_callback_set(callback); + return r; } else if (result != Py_None) { DEBUG_TRACE("create cache %s", get_frame_name(frame)); @@ -709,7 +729,7 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL}; static PyTypeObject THPPyInterpreterFrameType = { PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame", + .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame", .tp_basicsize = sizeof(THPPyInterpreterFrame), .tp_flags = Py_TPFLAGS_DEFAULT, .tp_getset = THPPyInterpreterFrame_properties, @@ -817,12 +837,27 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { Py_RETURN_NONE; } +static PyObject* set_context_frame(PyObject* dummy, PyObject* obj) { + int frame_id, frame_compile_id, attempt; + if (!PyArg_ParseTuple(obj, "iii", &frame_id, &frame_compile_id, &attempt)) { + PyErr_SetString(PyExc_TypeError, "Expected three integers"); + return NULL; + } + if (attempt == 0) { + sprintf(compile_context, "%d/%d", frame_id, frame_compile_id); + } else { + sprintf(compile_context, "%d/%d_%d", frame_id, frame_compile_id, attempt); + } + Py_RETURN_NONE; +} + static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame_py, METH_O, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, {"skip_code", skip_code, METH_O, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, + {"set_context_frame", set_context_frame, METH_O, NULL}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef _module = { @@ -865,5 +900,13 @@ PyObject* torch_c_dynamo_eval_frame_init(void) { } #endif + skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type); + if (skip_code_recursive_flag == NULL) { + return NULL; + } + if (PyModule_AddObject(module, "skip_code_recursive_flag", skip_code_recursive_flag) != 0) { + return NULL; + } + return module; } diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 01d29eab519716..d661f57d2cf32b 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -38,15 +38,20 @@ void ExtraState::invalidate(CacheEntry* cache_entry) { this->cache_entry_list.erase(cache_entry->_owner_loc); } +static bool is_extra_state_unset(ExtraState* extra_state) { + return extra_state == nullptr || extra_state == SKIP_CODE || + extra_state == SKIP_CODE_RECURSIVE; +} + CacheEntry* extract_cache_entry(ExtraState* extra_state) { - if (extra_state == nullptr || extra_state == SKIP_CODE) { + if (is_extra_state_unset(extra_state)) { return nullptr; } return extra_state->get_first_entry(); } FrameState* extract_frame_state(ExtraState* extra_state) { - if (extra_state == nullptr || extra_state == SKIP_CODE) { + if (is_extra_state_unset(extra_state)) { return nullptr; } return (FrameState*)extra_state->frame_state.ptr(); @@ -60,16 +65,14 @@ ExtraState* get_extra_state(PyCodeObject* code) { void destroy_extra_state(void* obj) { ExtraState* extra = (ExtraState*)obj; - if (extra != nullptr && extra != SKIP_CODE) { + if (!is_extra_state_unset(extra)) { delete extra; } } void set_extra_state(PyCodeObject* code, ExtraState* extra_state) { ExtraState* old_extra_state = get_extra_state(code); - CHECK( - old_extra_state == nullptr || old_extra_state == SKIP_CODE || - old_extra_state != extra_state); + CHECK(is_extra_state_unset(extra_state) || old_extra_state != extra_state); _PyCode_SetExtra((PyObject*)code, extra_index, extra_state); } @@ -80,19 +83,37 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) { ExtraState* extra_state = new ExtraState(); NULL_CHECK(extra_state); set_extra_state(code, extra_state); + // freed by destroy_extra_state (since we need to pass these objects to C) + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) return extra_state; } +bool backend_match(PyObject* saved_backend, PyObject* backend) { + // Pointer equality check for common case + if (saved_backend != backend) { + // The Py_TYPE check should not be required but there is a pre-existing + // issue where backend is possibly deallocated (or nullptr) and causes + // segfaults. Check test - test_inplace_custom_op_intermediate + return ( + Py_TYPE(saved_backend) == Py_TYPE(backend) && + PyObject_RichCompareBool(saved_backend, backend, Py_EQ)); + } + return true; +} + PyObject* lookup( ExtraState* extra_state, PyObject* f_locals, - const PyObject* backend) { + PyObject* backend) { size_t index = 0; CacheEntry* found = nullptr; py::handle locals(f_locals); for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. - bool valid = backend == Py_False || cache_entry.backend == backend; + + bool valid = + backend == Py_False || backend_match(cache_entry.backend, backend); + if (valid) { try { // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is @@ -157,7 +178,7 @@ py::list _debug_get_cache_entry_list(const py::handle& code_obj) { PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; - if (extra && extra != SKIP_CODE) { + if (!is_extra_state_unset(extra)) { for (CacheEntry& e : extra->cache_entry_list) { result.append(py::cast(e, py::return_value_policy::reference)); } diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 48784f9e4a6e00..1f6ccc7061a0c5 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -16,6 +16,8 @@ extern "C" { // Flag to just run a frame normally #define SKIP_CODE ((void*)0x1) +// Flag to run a frame and any recursive calls normally +#define SKIP_CODE_RECURSIVE ((void*)0x2) // Points to the extra scratch space on the code object extern Py_ssize_t extra_index; @@ -97,7 +99,7 @@ void destroy_extra_state(void* obj); // - there is no return, but the extra_state is stolen, so it becomes // set_extra_state responsibility to clean it up. It will be deleted during // the reset_code/skip, when the set_extra_state is called with -// NULL/SKIP_CODE. +// NULL/SKIP_CODE/SKIP_CODE_RECURSIVE. // Invariant - Dont set the extra state for the extra state that is already on // the code object. Otherwise, we will first free up the old extra state @@ -127,7 +129,7 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code); PyObject* lookup( ExtraState* extra_state, PyObject* f_locals, - const PyObject* backend); + PyObject* backend); // Create a new cache entry at extra_state holding on to guarded_code. // Ownership contract diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 0c2cf51a2cbbbb..1d1cf2fb0c60a6 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2515,62 +2515,40 @@ class TORCH_FUNCTION_MODE_STACK : public LeafGuard { public: TORCH_FUNCTION_MODE_STACK( const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) - : LeafGuard(std::move(verbose_code_parts)), - _ref_stack(), - _ignored_types() { + : LeafGuard(std::move(verbose_code_parts)), _ref_stack() { Py_ssize_t len = PyList_Size(initial_stack.ptr()); for (Py_ssize_t idx = 0; idx < len; idx++) { PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref - this->_ref_stack.push_back(Py_TYPE(mode)); - } - - len = PyList_Size(ignored_types.ptr()); - for (Py_ssize_t idx = 0; idx < len; idx++) { - PyObject* type_obj = - PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref - if (PyType_Check(type_obj) == 0) { - PyErr_SetString( - PyExc_TypeError, "ignored_types should contain a list of types"); - return; - } - PyTypeObject* type = (PyTypeObject*)type_obj; - this->_ignored_types.insert(type); + auto type = Py_TYPE(mode); + this->_ref_stack.push_back(type); } } bool check_nopybind(PyObject* value) override { // Ignore value arg, only used to satisfy the interface - size_t ref_ind = 0; - int64_t len = at::impl::PythonTorchFunctionTLS::stack_len(); + const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len(); const size_t ref_stack_size = this->_ref_stack.size(); - for (int64_t idx = 0; idx < len; idx++) { + if (len != ref_stack_size) { + return false; + } + + for (int64_t idx = 0; (size_t)idx < len; idx++) { std::shared_ptr mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter())); - // skip ignored types - if (this->_ignored_types.count(mode_type) > 0) { - continue; - } - // if we already have more non-ignored modes than the ref stack - // or if the mode doesn't match at the current index, return false - else if ( - (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) || - mode_type != _ref_stack[ref_ind]) { + if (mode_type != _ref_stack.at(idx)) { return false; } - ref_ind++; } - return ref_ind == this->_ref_stack.size(); + return true; } private: std::vector _ref_stack; - std::set _ignored_types; }; class TENSOR_MATCH : public LeafGuard { @@ -3438,6 +3416,69 @@ class WeakRefCallGuardAccessor : public GuardAccessor { } }; +/** + * Implements function call no args - e.g, torch.cuda.current_device() + */ +class CallFunctionNoArgsGuardAccessor : public GuardAccessor { + public: + CallFunctionNoArgsGuardAccessor( + RootGuardManager* root, + py::str name, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) + : GuardAccessor( + root, + std::move(name), + std::move(source), + example_value, + guard_manager_enum) {} + + // NB: Intentional duplication between check_nopybind and + // check_verbose_nopybind. + bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) + override { // borrowed ref + if (!PyCallable_Check(obj)) { + return false; + } + + PyObject* x = PyObject_CallNoArgs(obj); + if (x == nullptr) { + // Call failed, clear the exception and return false. + PyErr_Clear(); + return false; + } + + bool result = _guard_manager->check_nopybind(x); + Py_DECREF(x); + return result; + } + + GuardDebugInfo check_verbose_nopybind( + PyObject* obj) override { // borrowed ref + if (!PyCallable_Check(obj)) { + return GuardDebugInfo( + false, std::string("Not a callable obj ") + get_source(), 0); + } + + PyObject* x = PyObject_CallNoArgs(obj); + if (x == nullptr) { + // Call failed, clear the exception and return debug info. + std::string exc_message = get_exception_message(); + PyErr_Clear(); + return GuardDebugInfo(false, exc_message, 0); + } + + GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x); + Py_DECREF(x); + return result; + } + + std::string repr() const override { + return "CallFunctionNoArgsGuardAccessor()"; + } +}; + /** * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to * supply accessor as a python function. This is useful for from_numpy source. @@ -3672,7 +3713,7 @@ PyObject* torch_c_dynamo_guards_init() { LeafGuard, std::shared_ptr>( py_m, "TORCH_FUNCTION_MODE_STACK") - .def(py::init()) + .def(py::init()) .def("__call__", &TORCH_FUNCTION_MODE_STACK::check); py::class_>( py_m, "DATA_PTR_MATCH") @@ -3782,6 +3823,12 @@ PyObject* torch_c_dynamo_guards_init() { std::unique_ptr>( py_m, "WeakRefCallGuardAccessor"); // NOLINTNEXTLINE(bugprone-unused-raii) + py::class_< + CallFunctionNoArgsGuardAccessor, + GuardAccessor, + std::unique_ptr>( + py_m, "CallFunctionNoArgsGuardAccessor"); + // NOLINTNEXTLINE(bugprone-unused-raii) py::class_< TupleIteratorGetItemAccessor, GuardAccessor, @@ -3903,10 +3950,9 @@ PyObject* torch_c_dynamo_guards_init() { "add_torch_function_mode_stack_guard", [](GuardManager& self, const py::list& initial_stack, - const py::list& ignored_types, py::object verbose_code_parts) -> void { self.add_leaf_guard(std::make_shared( - initial_stack, ignored_types, std::move(verbose_code_parts))); + initial_stack, std::move(verbose_code_parts))); }) .def( "add_data_ptr_guard", @@ -4101,6 +4147,26 @@ PyObject* torch_c_dynamo_guards_init() { py::return_value_policy::reference) // return by reference because GuardManager has the ownership of accessors // and guard managers + .def( + "call_function_no_args_manager", + [](GuardManager& self, + std::string source, + py::handle example_value, + py::handle guard_manager_enum) -> GuardManager* { + // A unique key is used to save as the accessor key. + py::str unique_key("__call_function_no_args_accessor__"); + return self.get_child_manager( + std::move(unique_key), + std::move(source), + example_value, + guard_manager_enum); + }, + py::arg("source"), + py::arg("example_value"), + py::arg("guard_manager_enum"), + py::return_value_policy::reference) + // return by reference because GuardManager has the ownership of accessors + // and guard managers .def( "tuple_iterator_getitem_manager", &GuardManager::get_child_manager, diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index bf20bfe724f4d0..3d60a75b4561d1 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -44,6 +45,11 @@ void initDynamoBindings(PyObject* torch) { throw python_error(); } + PyObject* utils = torch_c_dynamo_utils_init(); + if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) { + throw python_error(); + } + PyObject* guards = torch_c_dynamo_guards_init(); if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) { throw python_error(); diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index ba008144d68c5d..da8b2888f25281 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -685,6 +685,24 @@ CacheNode* _compiled_autograd_impl( return cache; } +struct LockGuardWithErrorLogs { + LockGuardWithErrorLogs(std::mutex& mtx) : mtx_(mtx) { + // Note: the standard allows try_lock to fail spuriously during races for + // performance reasons, but it shouldn't happen here since we: + // 1. disable multithreaded autograd + // 2. plenty of latency between backward calls + TORCH_INTERNAL_ASSERT( + mtx_.try_lock(), + "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet."); + } + + ~LockGuardWithErrorLogs() { + mtx_.unlock(); + } + + std::mutex& mtx_; +}; + variable_list compiled_autograd( const std::shared_ptr& graph_root, GraphTask& graph_task, @@ -693,8 +711,8 @@ variable_list compiled_autograd( TORCH_CHECK( c10::impl::TorchDispatchModeTLS::stack_len() == 0, "TorchDispatchMode not yet implemented for compiled autograd") - static std::mutex lock; - std::lock_guard lock_guard(lock); + static std::mutex mtx; + LockGuardWithErrorLogs lock_guard(mtx); pybind11::gil_scoped_acquire gil; at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_); diff --git a/torch/csrc/dynamo/utils.cpp b/torch/csrc/dynamo/utils.cpp new file mode 100644 index 00000000000000..662b16bfb567e4 --- /dev/null +++ b/torch/csrc/dynamo/utils.cpp @@ -0,0 +1,33 @@ +#include + +namespace torch::dynamo { + +static std::array _methods = {{ + {nullptr, + nullptr, + 0, + nullptr} // Sentinel value indicating the end of the array +}}; + +bool is_instancemethod(py::object obj) { + return PyInstanceMethod_Check(obj.ptr()); +} + +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + "torch._C._dynamo.utils", + "Module containing C utils", + -1, + _methods.data()}; + +PyObject* torch_c_dynamo_utils_init() { + auto m = PyModule_Create(&_module); + if (m == nullptr) + return nullptr; + + auto py_m = py::handle(m).cast(); + py_m.def("is_instancemethod", is_instancemethod); + return m; +} + +} // namespace torch::dynamo diff --git a/torch/csrc/dynamo/utils.h b/torch/csrc/dynamo/utils.h index 3fd932f0441fe0..c7bcfb5b5e643f 100644 --- a/torch/csrc/dynamo/utils.h +++ b/torch/csrc/dynamo/utils.h @@ -1,5 +1,10 @@ #pragma once +#include +// C2039 MSVC +#include +#include +#include // The visibility attribute is to avoid a warning about storing a field in the // struct that has a different visibility (from pybind) than the struct. #ifdef _WIN32 @@ -7,3 +12,7 @@ #else #define VISIBILITY_HIDDEN __attribute__((visibility("hidden"))) #endif + +namespace torch::dynamo { +PyObject* torch_c_dynamo_utils_init(); +} // namespace torch::dynamo diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index aa9566acd95712..c364daa82e18b3 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -64,8 +64,8 @@ std::vector unpack_tensors( const c10::Device& device) { std::vector inputs; for (size_t idx = 0; idx < stack.size(); idx++) { - auto ivalue = stack[idx]; - auto ivalue_arg = arguments[idx]; + const auto& ivalue = stack[idx]; + const auto& ivalue_arg = arguments[idx]; if (ivalue.isTensor()) { unpack_tensor_ivalue(ivalue, device, inputs); } else if (ivalue.isTensorList()) { @@ -117,12 +117,10 @@ std::vector unpack_input_parameters( if (stack[idx].isScalar()) { // Beyond c10::Scalar, the floating value and interger value are also // represented as Scalar. - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toScalar(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toScalar(), arg_order); } else if (stack[idx].isTensorList()) { // tensor list - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toTensorList().vec(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toTensorList().vec(), arg_order); } else if (stack[idx].isOptionalTensorList()) { // optional tensor list: std::vector> std::vector tensor_list; @@ -131,27 +129,23 @@ std::vector unpack_input_parameters( tensor_list.push_back(item.toOptional().value()); } } - inputs_metadata.push_back(ParameterMetadata(tensor_list, arg_order)); + inputs_metadata.emplace_back(tensor_list, arg_order); } else if ( *arguments[idx].real_type() == *c10::getTypePtr>()) { // optional tensor if (stack[idx].toOptional().has_value()) { - inputs_metadata.push_back(ParameterMetadata( - stack[idx].toOptional().value(), arg_order)); + inputs_metadata.emplace_back( + stack[idx].toOptional().value(), arg_order); } } else if (stack[idx].isTensor()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toTensor(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toTensor(), arg_order); } else if (stack[idx].isString()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toStringRef(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toStringRef(), arg_order); } else if (stack[idx].isBool()) { - inputs_metadata.push_back( - ParameterMetadata(c10::Scalar(stack[idx].toBool()), arg_order)); + inputs_metadata.emplace_back(c10::Scalar(stack[idx].toBool()), arg_order); } else if (stack[idx].isDevice()) { - inputs_metadata.push_back( - ParameterMetadata(stack[idx].toDevice(), arg_order)); + inputs_metadata.emplace_back(stack[idx].toDevice(), arg_order); } else { TORCH_CHECK_NOT_IMPLEMENTED( false, @@ -239,7 +233,7 @@ void AOTIPythonKernelHolder::cache_hit( auto outputs = aoti_kernel_metadata.kernel_runner_->run(inputs); for (auto& output : outputs) { - stack->push_back(output); + stack->emplace_back(output); } } @@ -343,8 +337,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { auto tensor_metadata = build_tensor_metadata(metadata); test_list_metadata.push_back(tensor_metadata); } - parameter_metadata_list.push_back( - ParameterMetadata(test_list_metadata, arg_idx)); + parameter_metadata_list.emplace_back(test_list_metadata, arg_idx); } else if (is_scalar) { // Scalar auto metadata = item_metadata.cast(); @@ -367,14 +360,12 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { dtype_value); } - parameter_metadata_list.push_back( - ParameterMetadata(c10::Scalar(scalar), arg_idx)); + parameter_metadata_list.emplace_back(c10::Scalar(scalar), arg_idx); } else if (is_string) { // String auto metadata = item_metadata.cast(); auto str_value = metadata["string_value"].cast(); - parameter_metadata_list.push_back( - ParameterMetadata(str_value, arg_idx)); + parameter_metadata_list.emplace_back(str_value, arg_idx); } else if (is_dtype) { // Dtype auto metadata = item_metadata.cast(); @@ -382,8 +373,8 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { TORCH_INTERNAL_ASSERT(THPDtype_Check(dtype_value_obj.ptr())); auto dtype_value = reinterpret_cast(dtype_value_obj.ptr())->scalar_type; - parameter_metadata_list.push_back(ParameterMetadata( - c10::Scalar(static_cast(dtype_value)), arg_idx)); + parameter_metadata_list.emplace_back( + c10::Scalar(static_cast(dtype_value)), arg_idx); } else if (is_device) { // Device auto metadata = item_metadata.cast(); @@ -395,21 +386,20 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { metadata["device_index_value"].cast(); device.set_index(device_index_value); } - parameter_metadata_list.push_back(ParameterMetadata(device, arg_idx)); + parameter_metadata_list.emplace_back(device, arg_idx); } else if (is_layout) { auto metadata = item_metadata.cast(); auto layout_value_obj = metadata["layout_value"].cast(); TORCH_INTERNAL_ASSERT(THPLayout_Check(layout_value_obj.ptr())); auto layout_value = reinterpret_cast(layout_value_obj.ptr())->layout; - parameter_metadata_list.push_back(ParameterMetadata( - c10::Scalar(static_cast(layout_value)), arg_idx)); + parameter_metadata_list.emplace_back( + c10::Scalar(static_cast(layout_value)), arg_idx); } else { // Tensor auto metadata = item_metadata.cast(); auto tensor_metadata = build_tensor_metadata(metadata); - parameter_metadata_list.push_back( - ParameterMetadata(tensor_metadata, arg_idx)); + parameter_metadata_list.emplace_back(tensor_metadata, arg_idx); } } @@ -480,9 +470,12 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( schema.overload_name().empty() ? "default" : schema.overload_name(); auto pos = qualified_name.find("::"); TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); - std::string ns_str(qualified_name.begin(), qualified_name.begin() + pos); + std::string ns_str( + qualified_name.begin(), + qualified_name.begin() + static_cast(pos)); std::string func_name( - qualified_name.begin() + pos + strlen("::"), qualified_name.end()); + qualified_name.begin() + static_cast(pos + strlen("::")), + qualified_name.end()); py::gil_scoped_acquire gil; py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* { diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index a5ca876d45392f..6562447bcba4fa 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -1,6 +1,7 @@ #if !defined(C10_MOBILE) && !defined(ANDROID) #include #include +#include namespace torch::inductor { @@ -25,8 +26,8 @@ TensorMetadata::TensorMetadata( dtype_(dtype), device_(device), dispatch_key_set_(dispatch_key_set), - sizes_(sizes), - strides_(strides), + sizes_(std::move(sizes)), + strides_(std::move(strides)), requires_grad_(requires_grad) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( !is_symbolic_, "Not support symbolic shape now"); @@ -94,25 +95,24 @@ bool TensorMetadata::operator==(const TensorMetadata& other) const { std::ostream& operator<<( std::ostream& stream, const TensorMetadata& tensor_metadata) { - stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << std::endl; - stream << "dtype_: " << tensor_metadata.dtype_ << std::endl; - stream << "device_: " << tensor_metadata.device_ << std::endl; + stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << '\n'; + stream << "dtype_: " << tensor_metadata.dtype_ << '\n'; + stream << "device_: " << tensor_metadata.device_ << '\n'; stream << "sizes_: "; for (const auto& size : tensor_metadata.sizes_) { stream << size << " "; } - stream << std::endl; + stream << '\n'; stream << "strides_: "; for (const auto& stride : tensor_metadata.strides_) { stream << stride << " "; } - stream << "requires_grad_: " << tensor_metadata.requires_grad_ << std::endl; - stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ - << std::endl; + stream << "requires_grad_: " << tensor_metadata.requires_grad_ << '\n'; + stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ << '\n'; stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value() - << std::endl; - stream << std::endl; + << '\n'; + stream << '\n'; return stream; } @@ -138,8 +138,9 @@ ParameterMetadata::ParameterMetadata( uint64_t input_order) : tag_(TENSOR_LIST), order_(input_order) { std::vector tensor_metadata_list; + tensor_metadata_list.reserve(tensor_list.size()); for (const auto& tensor : tensor_list) { - tensor_metadata_list.push_back(TensorMetadata(tensor)); + tensor_metadata_list.emplace_back(tensor); } value_ = tensor_metadata_list; } @@ -147,23 +148,17 @@ ParameterMetadata::ParameterMetadata( ParameterMetadata::ParameterMetadata( const c10::Scalar& scalar, uint64_t input_order) - : tag_(SCALAR), order_(input_order) { - value_ = scalar; -} + : tag_(SCALAR), value_(scalar), order_(input_order) {} ParameterMetadata::ParameterMetadata( const std::string& str, uint64_t input_order) - : tag_(STRING), order_(input_order) { - value_ = str; -} + : tag_(STRING), value_(str), order_(input_order) {} ParameterMetadata::ParameterMetadata( const c10::Device& device, uint64_t input_order) - : tag_(DEVICE), order_(input_order) { - value_ = device; -} + : tag_(DEVICE), value_(device), order_(input_order) {} bool ParameterMetadata::operator==(const ParameterMetadata& other) const { // Same type diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h index d84ea9b477ef35..24d3c05bc3505c 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.h +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.h @@ -10,8 +10,8 @@ namespace torch::inductor { // Regarding a aten operation implemented by AOTI, the metadata of the input -// tensors will be cached on the disk to acclerate next run. TensorMetada -// structure is to represent the metadata of each input tensor. it includes +// tensors will be cached on the disk to accelerate next run. TensorMetada +// structure is to represent the metadata of each input tensor. It includes // whether the tensor is symbolic, the dtype, the device, the sizes and the // strides of the tensor. When the metadata of the input tensors is the same as // the cached metadata, the cached kernel library will be loaded and executed. @@ -51,7 +51,6 @@ struct TensorMetadata { TensorMetadata() : is_symbolic_(false), - dtype_(c10::ScalarType::Undefined), device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES), sizes_({}), strides_({}) {} @@ -116,7 +115,7 @@ struct ParameterMetadata { // same tag. For example, an operation with two input tensors, the first // tensor is a optional tensor and the second tensor is a tensor. The first // tensor will have the order 0 and the second tensor will have the order 1. - uint64_t order_; + uint64_t order_{}; ParameterMetadata() : tag_(INVALID) {} ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order); diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp new file mode 100644 index 00000000000000..5f758787e658fa --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -0,0 +1,411 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) + +#include +#include +#include + +#include +#include +#include +#include +#include + +// TODO: Investigate why this is necessary, but fixes build problems in FRL +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +#ifndef _WIN32 +#include +#endif + +// TODO: C++17 has the filesystem header, which may replace these +#ifdef _WIN32 +// On Windows, the POSIX implementations are considered deprecated. We simply +// map to the newer variant. +#include +#include +#include +#define access _access +#define F_OK 0 +#else +#include +#include +#endif + +namespace { +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +std::string create_temp_dir() { +#ifdef _WIN32 + throw std::runtime_error("Not implemented"); +#else + std::string temp_dir = "/tmp/XXXXXX"; + if (mkdtemp(temp_dir.data()) == nullptr) { + throw std::runtime_error( + std::string("Failed to create temporary directory: ") + + strerror(errno)); + } + return temp_dir; +#endif +} +} // namespace + +namespace torch::inductor { + +namespace { +const nlohmann::json& load_json_file(std::string json_path) { + if (!file_exists(json_path)) { + throw std::runtime_error("File found: " + json_path); + } + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open()); + static nlohmann::json json_obj; + json_file >> json_obj; + + return json_obj; +} + +std::tuple get_cpp_compile_command( + const std::string& filename, + const std::vector& sources, + const nlohmann::json& compile_options, + const std::string& output_dir = "") { + // Construct the cpp command + + std::string compiler = compile_options["compiler"].get(); + bool compile_only = compile_options["compile_only"].get(); + + std::string source_args = ""; + for (const std::string& source : sources) { + source_args += source + " "; + } + + std::string file_ext = compile_only ? ".o" : ".so"; + std::string target_file = output_dir + filename + file_ext; + + std::string cflags_args = ""; + for (auto& arg : compile_options["cflags"]) { + cflags_args += "-" + arg.get() + " "; + } + + std::string definitions_args = ""; + for (auto& arg : compile_options["definitions"]) { + definitions_args += "-D " + arg.get() + " "; + } + + std::string include_dirs_args = ""; + for (auto& arg : compile_options["include_dirs"]) { + include_dirs_args += "-I" + arg.get() + " "; + } + + std::string ldflags_args = ""; + for (auto& arg : compile_options["ldflags"]) { + ldflags_args += "-" + arg.get() + " "; + } + + std::string libraries_dirs_args = ""; + for (auto& arg : compile_options["libraries_dirs"]) { + libraries_dirs_args += "-L" + arg.get() + " "; + } + + std::string libraries_args = ""; + for (auto& arg : compile_options["libraries"]) { + libraries_args += "-l" + arg.get() + " "; + } + + std::string passthrough_parameters_args = ""; + for (auto& arg : compile_options["passthrough_args"]) { + passthrough_parameters_args += arg.get() + " "; + } + + std::string compile_only_arg = compile_only ? "-c" : ""; + + std::string cmd = fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file); + + return std::make_tuple(cmd, target_file); +} + +bool recursive_mkdir(const std::string& dir) { + // Creates directories recursively, copied from jit_utils.cpp + // Check if current dir exists + const char* p_dir = dir.c_str(); + const bool dir_exists = (access(p_dir, F_OK) == 0); + if (dir_exists) { + return true; + } + + // Try to create current directory +#ifdef _WIN32 + int ret = _mkdir(dir.c_str()); +#else + int ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + // Success + if (ret == 0) { + return true; + } + + // Find folder separator and check if we are at the top + auto pos = dir.find_last_of("/\\"); + if (pos == std::string::npos) { + return false; + } + + // Try to create parent directory + if (!(recursive_mkdir(dir.substr(0, pos)))) { + return false; + } + + // Try to create complete path again +#ifdef _WIN32 + ret = _mkdir(dir.c_str()); +#else + ret = mkdir(dir.c_str(), S_IRWXU | S_IRWXG | S_IRWXO); +#endif + return ret == 0; +} + +std::string compile_so( + const std::string& cpp_filename, + const std::string& consts_filename) { + // Compile the cpp file into a .so + + size_t lastindex = cpp_filename.find_last_of('.'); + std::string filename = cpp_filename.substr(0, lastindex); + + std::string compile_flags_path = filename + "_compile_flags.json"; + const nlohmann::json compile_flags = load_json_file(compile_flags_path); + + auto compile_result = + get_cpp_compile_command(filename, {cpp_filename}, compile_flags); + std::string compile_cmd = std::get<0>(compile_result); + std::string output_o = std::get<1>(compile_result); + + std::string linker_flags_path = + cpp_filename.substr(0, lastindex) + "_linker_flags.json"; + const nlohmann::json linker_flags = load_json_file(linker_flags_path); + + auto link_result = get_cpp_compile_command( + filename, {output_o, consts_filename}, linker_flags); + std::string link_cmd = std::get<0>(link_result); + std::string output_so = std::get<1>(link_result); + + // Run the commands to generate a .so file + int status = system(compile_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to compile cpp file."); + } + status = system(link_cmd.c_str()); + if (status != 0) { + throw std::runtime_error("Failed to link files."); + } + + // Move the mmapped weights onto the .so + std::string serialized_weights_path = filename + "_serialized_weights.bin"; + if (file_exists(serialized_weights_path)) { + std::ifstream serialized_weights_file( + serialized_weights_path, std::ios::binary); + if (!serialized_weights_file.is_open()) { + throw std::runtime_error("Failed to open serialized weights file"); + } + std::vector serialized_weights( + (std::istreambuf_iterator(serialized_weights_file)), + std::istreambuf_iterator()); + serialized_weights_file.close(); + + std::ofstream output_so_file(output_so, std::ios::binary | std::ios::app); + if (!output_so_file.is_open()) { + throw std::runtime_error("Failed to open output .so file"); + } + // Page align the weights + std::streampos so_size = output_so_file.tellp(); + std::vector padding(16384 - so_size % 16384, ' '); + output_so_file.write( + padding.data(), static_cast(padding.size())); + output_so_file.write( + serialized_weights.data(), + static_cast(serialized_weights.size())); + output_so_file.close(); + } + + return output_so; +} +} // namespace + +void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { + // Parse metadata json file (if it exists) into the metadata_ map + size_t lastindex = cpp_filename.find_last_of('.'); + std::string metadata_json_path = + cpp_filename.substr(0, lastindex) + "_metadata.json"; + + const nlohmann::json metadata_json_obj = load_json_file(metadata_json_path); + + for (auto& item : metadata_json_obj.items()) { + metadata_[item.key()] = item.value().get(); + } +} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path) + : AOTIModelPackageLoader(model_package_path, "model") {} + +AOTIModelPackageLoader::AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name = "model") { + // Extract all files within the zipfile to a temporary directory + mz_zip_archive zip_archive; + memset(&zip_archive, 0, sizeof(zip_archive)); + + if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { + throw std::runtime_error( + std::string("Failed to initialize zip archive: ") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + std::string temp_dir = create_temp_dir(); + std::string so_filename = ""; + std::string cpp_filename = ""; + std::string consts_filename = ""; + std::string found_filenames = ""; // Saving for bookkeeping + + for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { + uint32_t filename_len = + mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); + if (filename_len == 0) { + throw std::runtime_error("Failed to read filename"); + } + char* filename = new char[filename_len + 1]; + if (!mz_zip_reader_get_filename(&zip_archive, i, filename, filename_len)) { + throw std::runtime_error("Failed to read filename"); + } + + std::string filename_str(filename); + found_filenames += filename_str; + found_filenames += " "; + + // Only compile files in the specified model directory + std::string model_directory = "data/aotinductor/" + model_name; + if (filename_str.length() >= model_directory.length() && + filename_str.substr(0, model_directory.length()) == model_directory) { + std::string output_path_str = temp_dir; + output_path_str += "/"; + output_path_str += filename_str; + + // Create the parent directory if it doesn't exist + size_t parent_path_idx = output_path_str.find_last_of("/\\"); + if (parent_path_idx == std::string::npos) { + throw std::runtime_error( + "Failed to find parent path in " + output_path_str); + } + std::string parent_path = output_path_str.substr(0, parent_path_idx); + if (!recursive_mkdir(parent_path.c_str())) { + throw std::runtime_error(fmt::format( + "Failed to create directory {}: {}", parent_path, strerror(errno))); + } + + // Extracts file to the temp directory + mz_zip_reader_extract_file_to_file( + &zip_archive, filename, output_path_str.c_str(), 0); + + // Save the file for bookkeeping + size_t extension_idx = output_path_str.find_last_of('.'); + if (extension_idx != std::string::npos) { + std::string filename_extension = output_path_str.substr(extension_idx); + if (filename_extension == ".cpp") { + cpp_filename = output_path_str; + } + if (filename_extension == ".o") { + consts_filename = output_path_str; + } + if (filename_extension == ".so") { + so_filename = output_path_str; + } + } + } + } + + // Close the zip archive as we have extracted all files to the temp directory + if (!mz_zip_reader_end(&zip_archive)) { + throw std::runtime_error( + std::string("Failed to close zip archive: {}") + + mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); + } + + if (cpp_filename.empty() && so_filename.empty()) { + throw std::runtime_error( + "No AOTInductor generate cpp file or so file found in zip archive. Loaded the following:\n" + + found_filenames); + } + + // Compile the .so + std::string so_path = !so_filename.empty() + ? so_filename + : compile_so(cpp_filename, consts_filename); + + // Load metadata which can be queried by user + load_metadata(cpp_filename); + + // Construct the runner depending on the device information + std::string device = metadata_["AOTI_DEVICE_KEY"]; + + if (device.empty()) { + throw std::runtime_error("No device information found."); + } + + std::unordered_map + registered_aoti_runner = getAOTIModelRunnerRegistry(); + + if (registered_aoti_runner.find(device) == registered_aoti_runner.end()) { + throw std::runtime_error("Unsupported device found: " + device); + } + + runner_ = registered_aoti_runner[device](so_path, 1, device, ""); + + std::remove(temp_dir.c_str()); +} + +AOTIModelContainerRunner* AOTIModelPackageLoader::get_runner() { + return runner_.get(); +} + +std::vector AOTIModelPackageLoader::run( + std::vector& inputs) { + return runner_->run(inputs); +} + +std::unordered_map AOTIModelPackageLoader:: + get_metadata() { + return metadata_; +} + +std::vector AOTIModelPackageLoader::get_call_spec() { + return runner_->get_call_spec(); +} + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.h b/torch/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 00000000000000..70c9514849da5b --- /dev/null +++ b/torch/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,28 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader(const std::string& model_package_path); + AOTIModelPackageLoader( + const std::string& model_package_path, + const std::string& model_name); + + AOTIModelContainerRunner* get_runner(); + std::unordered_map get_metadata(); + std::vector run(std::vector& inputs); + std::vector get_call_spec(); + + private: + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string& cpp_filename); +}; + +} // namespace torch::inductor +#endif diff --git a/torch/csrc/inductor/aoti_package/pybind.cpp b/torch/csrc/inductor/aoti_package/pybind.cpp new file mode 100644 index 00000000000000..3d2154a7549327 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.cpp @@ -0,0 +1,24 @@ +#include +#include +#include +#ifdef USE_CUDA +#include +#endif + +#include +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module) { + auto rootModule = py::handle(module).cast(); + auto m = rootModule.def_submodule("_aoti"); + + py::class_(m, "AOTIModelPackageLoader") + .def(py::init()) + .def(py::init()) + .def("get_metadata", &AOTIModelPackageLoader::get_metadata) + .def("run", &AOTIModelPackageLoader::run) + .def("get_call_spec", &AOTIModelPackageLoader::get_call_spec); +} +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_package/pybind.h b/torch/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 00000000000000..1eb7818c00e906 --- /dev/null +++ b/torch/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject* module); + +} // namespace torch::inductor diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 4941e4e3210a45..37c69ccfa813d0 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -14,6 +14,21 @@ namespace fs = std::filesystem; namespace fs = std::experimental::filesystem; #endif +#ifndef _WIN32 +#include +#endif + +namespace { +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} +} // namespace + namespace torch::inductor { AOTIModelContainerRunner::AOTIModelContainerRunner( @@ -57,10 +72,10 @@ AOTIModelContainerRunner::AOTIModelContainerRunner( model_so_->sym("AOTInductorModelContainerGetCallSpec")); // Hack to find the json file name from the model so file - size_t lastindex = model_so_path.find_last_of("."); + size_t lastindex = model_so_path.find_last_of('.'); std::string json_filename = model_so_path.substr(0, lastindex) + ".json"; - if (fs::exists(json_filename)) { + if (file_exists(json_filename)) { proxy_executor_ = std::make_unique( json_filename, device_str == "cpu"); proxy_executor_handle_ = @@ -174,8 +189,8 @@ void AOTIModelContainerRunner::swap_constant_buffer() { } std::vector AOTIModelContainerRunner::get_call_spec() { - const char* in_spec; - const char* out_spec; + const char* in_spec = nullptr; + const char* out_spec = nullptr; AOTI_RUNTIME_ERROR_CODE_CHECK( get_call_spec_func_(container_handle_, &in_spec, &out_spec)); return {in_spec, out_spec}; diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.h b/torch/csrc/inductor/aoti_runner/model_container_runner.h index 99669aec14ac09..6e6339d3dd273c 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.h +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.h @@ -82,7 +82,7 @@ class TORCH_API AOTIModelContainerRunner { std::unique_ptr proxy_executor_; }; -using CreateAOTIModelRunnerFunc = std::shared_ptr (*)( +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( const std::string& model_so_path, size_t num_models, const std::string& device_str, diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp index 40eb7407ecd317..f40545d04c4935 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.cpp @@ -10,12 +10,28 @@ AOTIModelContainerRunnerCpu::AOTIModelContainerRunnerCpu( size_t num_models) : AOTIModelContainerRunner(model_so_path, num_models, "cpu", "") {} -AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() {} +AOTIModelContainerRunnerCpu::~AOTIModelContainerRunnerCpu() = default; std::vector AOTIModelContainerRunnerCpu::run( std::vector& inputs) { return AOTIModelContainerRunner::run(inputs); } +namespace { +std::unique_ptr create_aoti_runner_cpu( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + if (device_str != "cpu") { + throw std::runtime_error("Incorrect device passed to aoti_runner_cpu"); + } + return std::make_unique( + model_so_path, num_models); +} +} // namespace + +RegisterAOTIModelRunner register_cpu_runner("cpu", &create_aoti_runner_cpu); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp index 705a59eb3f394f..3ddad0885aa53d 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp @@ -14,7 +14,7 @@ AOTIModelContainerRunnerCuda::AOTIModelContainerRunnerCuda( device_str, cubin_dir) {} -AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() {} +AOTIModelContainerRunnerCuda::~AOTIModelContainerRunnerCuda() = default; std::vector AOTIModelContainerRunnerCuda::run( std::vector& inputs) { @@ -30,5 +30,18 @@ std::vector AOTIModelContainerRunnerCuda::run_with_cuda_stream( inputs, reinterpret_cast(cuda_stream.stream())); } +namespace { +std::unique_ptr create_aoti_runner_cuda( + const std::string& model_so_path, + size_t num_models, + const std::string& device_str, + const std::string& cubin_dir) { + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir); +} +} // namespace + +RegisterAOTIModelRunner register_cuda_runner("cuda", &create_aoti_runner_cuda); + } // namespace torch::inductor #endif diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 5c3869c9a4f953..b470b5f10061b8 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -51,6 +51,8 @@ #endif // _WIN32 #endif // __GNUC__ +// The following files are implemented in a header-only way and are guarded by +// test/cpp/aoti_abi_check #include #include #include @@ -493,7 +495,7 @@ aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16( // This will soon be deprecated after ao_quantization is complete. // Please refrain from using this or increasing callsites. -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AtenTensorHandle weight, AtenTensorHandle weight_scale, AtenTensorHandle weight_zero_point, @@ -513,7 +515,7 @@ aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( // This will soon be deprecated after ao_quantization is complete. // Please refrain from using this or increasing callsites. AOTI_TORCH_EXPORT AOTITorchError -aoti_torch_cpu_wrapped_quantized_linear_prepacked( +aoti_torch_cpu__wrapped_quantized_linear_prepacked( AtenTensorHandle input, AtenTensorHandle input_scale, AtenTensorHandle input_zero_point, @@ -573,6 +575,14 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg); +// When AOTI debug printer option is enabled, this function will be invoked to +// torch pickle save the intermediate tensor for debugging purpose. +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( + AtenTensorHandle self, + const char* tensor_name, + const char* launch_prefix, + const char* kernel_name); + #ifdef USE_CUDA struct CUDAGuardOpaque; @@ -604,7 +614,7 @@ aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); -#endif +#endif // USE_CUDA // See `ProxyExecutor Design Note` in ir.py for more details AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h new file mode 100644 index 00000000000000..63b35f62b3dc34 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h @@ -0,0 +1,39 @@ +#ifndef AOTI_TORCH_SHIM_MKLDNN +#define AOTI_TORCH_SHIM_MKLDNN + +#include +#include + +#if AT_MKLDNN_ENABLED() +#ifdef __cplusplus +extern "C" { +#endif + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_mkldnn_rnn_layer( + AtenTensorHandle input, + AtenTensorHandle weight0, + AtenTensorHandle weight1, + AtenTensorHandle weight2, + AtenTensorHandle weight3, + AtenTensorHandle hx_, + AtenTensorHandle cx_, + int32_t reverse, + const int64_t* batch_sizes, + int64_t batch_sizes_len_, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + int32_t has_biases, + int32_t bidirectional, + int32_t batch_first, + int32_t train, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1, + AtenTensorHandle* ret2, + AtenTensorHandle* ret3); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // AT_MKLDNN_ENABLED() +#endif // AOTI_TORCH_SHIM_MKLDNN diff --git a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp index 7f0811f0d88b50..23f15043241788 100644 --- a/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp +++ b/torch/csrc/inductor/aoti_torch/mkldnn_tensor.cpp @@ -6,8 +6,7 @@ #include #endif -namespace torch { -namespace aot_inductor { +namespace torch::aot_inductor { #if AT_MKLDNN_ENABLED() @@ -45,5 +44,4 @@ at::Tensor mkldnn_tensor_from_data_ptr( #endif -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 8e9c1a4b18a3e8..f49bf23b9ce422 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -24,6 +26,8 @@ #include #include #include +#include +#include #include #include #include @@ -40,11 +44,50 @@ #include #include #include -#include -#include #endif +#if __has_include("filesystem") +#include +namespace fs = std::filesystem; +#else +#include +namespace fs = std::experimental::filesystem; +#endif + +#ifndef _WIN32 +#include +#endif + +// HACK for failed builds in ARVR, where it cannot find these symbols within +// std::experimental::filesystem +namespace { +fs::path get_current_path() { +#if __has_include("filesystem") + return fs::current_path(); +#else + throw std::runtime_error("Not implemented"); +#endif +} + +bool file_exists(std::string& path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc; + return lstat(path.c_str(), &rc) == 0; +#endif +} + +bool create_directories(const std::string& path) { +#if __has_include("filesystem") + return fs::create_directories(path); +#else + throw std::runtime_error("Not implemented"); +#endif +} +} // namespace + using namespace torch::aot_inductor; namespace { @@ -804,7 +847,7 @@ AOTITorchError aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16( }); } -AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( +AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AtenTensorHandle weight, AtenTensorHandle weight_scale, AtenTensorHandle weight_zero_point, @@ -818,7 +861,7 @@ AOTITorchError aoti_torch_cpu_wrapped_linear_prepack( tensor_handle_to_tensor_pointer(weight_zero_point); at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias); - *out = new_tensor_handle(at::wrapped_linear_prepack( + *out = new_tensor_handle(at::_wrapped_linear_prepack( *weight_tensor, *weight_scale_tensor, *weight_zero_point_tensor, @@ -842,7 +885,7 @@ AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( }); } -AOTITorchError aoti_torch_cpu_wrapped_quantized_linear_prepacked( +AOTITorchError aoti_torch_cpu__wrapped_quantized_linear_prepacked( AtenTensorHandle input, AtenTensorHandle input_scale, AtenTensorHandle input_zero_point, @@ -861,7 +904,7 @@ AOTITorchError aoti_torch_cpu_wrapped_quantized_linear_prepacked( at::Tensor* out_scale_tensor = tensor_handle_to_tensor_pointer(out_scale); at::Tensor* out_zeropoint_tensor = tensor_handle_to_tensor_pointer(out_zeropoint); - *out = new_tensor_handle(at::wrapped_quantized_linear_prepacked( + *out = new_tensor_handle(at::_wrapped_quantized_linear_prepacked( *input_tensor, *input_scale_tensor, *input_zero_point_tensor, @@ -986,13 +1029,45 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( }); } +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( + AtenTensorHandle self, + const char* tensor_name, + const char* launch_prefix, + const char* kernel_name) { + at::Tensor* t = tensor_handle_to_tensor_pointer(self); +#ifndef C10_MOBILE + // Save tensor to tmp .pt file for tensors and can be torch.load'ed later + std::string cwd = get_current_path().string(); + std::string tmp_folder = cwd + "/tmp/aoti_torch/"; + if (!file_exists(tmp_folder)) { + std::cout + << "aoti_torch_save_tensor_handle: Path does not exist, creating it..." + << tmp_folder << std::endl; + + if (!create_directories(tmp_folder)) { + std::cout << "aoti_torch_save_tensor_handle: Error creating directory: " + << tmp_folder << std::endl; + return; + } + } + std::string tensor_filepath_to_save = tmp_folder + launch_prefix + "_" + + kernel_name + "_" + tensor_name + "_" + t->device().str() + ".pt"; + + auto bytes = torch::jit::pickle_save(c10::IValue(*t)); + std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary); + fout.write(bytes.data(), bytes.size()); + fout.close(); + + std::cout << "aoti_torch_save_tensor_handle: Saved tensor to " + << tensor_filepath_to_save << std::endl; +#endif // !defined(C10_MOBILE) +} + AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg) { at::Tensor* t = tensor_handle_to_tensor_pointer(self); - auto device = t->device(); - // Display message std::cout << "["; if (msg) { @@ -1009,15 +1084,31 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( // Print summary stats of the tensor std::cout << "Number of elements: " << numel << std::endl; + std::cout << "Dtype: " << t->dtype() << std::endl; if (numel > 0) { - std::cout << "Mean value: " << t->mean().item() << std::endl; - std::cout << "Min value: " << t->min().item() << std::endl; - std::cout << "Max value: " << t->max().item() << std::endl; + // torch/aten `mean()` function only supports float and complex dtypes + // See: + // https://github.com/pytorch/pytorch/blob/a0e062c6f1a03ec93e87413e42c4d0b336518131/aten/src/ATen/native/ReduceOps.cpp#L304-L309 + auto mean_value = [t](at::ScalarType dtype) { + return t->to(dtype).mean().item(); + }; + bool is_complex_type = + at::isComplexType(at::typeMetaToScalarType(t->dtype())); + at::ScalarType float_dtype = + is_complex_type ? at::kComplexFloat : at::kFloat; + std::cout << "Mean value: " << mean_value(float_dtype) << std::endl; + if (!is_complex_type) { + // "min_all_cuda" function is not implemented for 'ComplexFloat' type. + // (similar for max) Skip printing min/max value for complex type tensors + // here If encountered complex dtypes (rare occasions), suggest to print + // out the whole value of the tensor. + std::cout << "Min value: " << t->min().item() << std::endl; + std::cout << "Max value: " << t->max().item() << std::endl; + } } - std::cout << "Device: " << device << std::endl; + std::cout << "Device: " << t->device() << std::endl; std::cout << "Size: " << t->sizes() << std::endl; std::cout << "Stride: " << t->strides() << std::endl; - std::cout << "Dtype: " << t->dtype() << std::endl; std::cout << "Layout: " << t->layout() << std::endl; std::cout << "Is contiguous: " << t->is_contiguous() << std::endl; std::cout << "Requires grad: " << t->requires_grad() << std::endl; diff --git a/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp new file mode 100644 index 00000000000000..14f9fbf69459e6 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp @@ -0,0 +1,62 @@ + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +using namespace torch::aot_inductor; + +#if AT_MKLDNN_ENABLED() + +AOTITorchError aoti_torch_cpu_mkldnn_rnn_layer( + AtenTensorHandle input, + AtenTensorHandle weight0, + AtenTensorHandle weight1, + AtenTensorHandle weight2, + AtenTensorHandle weight3, + AtenTensorHandle hx_, + AtenTensorHandle cx_, + int32_t reverse, + const int64_t* batch_sizes, + int64_t batch_sizes_len_, + int64_t mode, + int64_t hidden_size, + int64_t num_layers, + int32_t has_biases, + int32_t bidirectional, + int32_t batch_first, + int32_t train, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1, + AtenTensorHandle* ret2, + AtenTensorHandle* ret3) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::cpu::mkldnn_rnn_layer( + *tensor_handle_to_tensor_pointer(input), + *tensor_handle_to_tensor_pointer(weight0), + *tensor_handle_to_tensor_pointer(weight1), + *tensor_handle_to_tensor_pointer(weight2), + *tensor_handle_to_tensor_pointer(weight3), + *tensor_handle_to_tensor_pointer(hx_), + *tensor_handle_to_tensor_pointer(cx_), + reverse, + pointer_to_list(batch_sizes, batch_sizes_len_), + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train); + *ret0 = new_tensor_handle(std::move(std::get<0>(tmp_result))); + *ret1 = new_tensor_handle(std::move(std::get<1>(tmp_result))); + *ret2 = new_tensor_handle(std::move(std::get<2>(tmp_result))); + *ret3 = new_tensor_handle(std::move(std::get<3>(tmp_result))); + }); +} + +#endif // AT_MKLDNN_ENABLED() diff --git a/torch/csrc/inductor/aoti_torch/tensor_converter.cpp b/torch/csrc/inductor/aoti_torch/tensor_converter.cpp index d61fa20bab8789..b53a1d8811d811 100644 --- a/torch/csrc/inductor/aoti_torch/tensor_converter.cpp +++ b/torch/csrc/inductor/aoti_torch/tensor_converter.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace aot_inductor { +namespace torch::aot_inductor { std::vector unsafe_alloc_new_handles_from_tensors( std::vector& tensors) { @@ -45,5 +44,4 @@ std::vector alloc_tensors_by_stealing_from_handles( return result; } -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/inductor_ops.cpp b/torch/csrc/inductor/inductor_ops.cpp index ec9205f130249a..7d0e9b612343b8 100644 --- a/torch/csrc/inductor/inductor_ops.cpp +++ b/torch/csrc/inductor/inductor_ops.cpp @@ -10,8 +10,7 @@ #include -namespace torch { -namespace inductor { +namespace torch::inductor { using namespace at; Tensor _mm_plus_mm_out( @@ -111,5 +110,4 @@ TORCH_LIBRARY_FRAGMENT(inductor, m) { {at::Tag::pt2_compliant_tag}); } -} // namespace inductor -} // namespace torch +} // namespace torch::inductor diff --git a/torch/csrc/inductor/resize_storage_bytes.cpp b/torch/csrc/inductor/resize_storage_bytes.cpp index 94522b7df8c82d..018acb1a0fc5a2 100644 --- a/torch/csrc/inductor/resize_storage_bytes.cpp +++ b/torch/csrc/inductor/resize_storage_bytes.cpp @@ -7,8 +7,7 @@ #include #endif -namespace torch { -namespace inductor { +namespace torch::inductor { using namespace at; // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -63,5 +62,4 @@ TORCH_LIBRARY_IMPL(inductor, Functionalize, m) { "resize_storage_bytes_", TORCH_FN(resize_storage_bytes__functionalize)); } -} // namespace inductor -} // namespace torch +} // namespace torch::inductor diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h index 01e7a3c98e3024..b5d336db2b6ec2 100644 --- a/torch/csrc/jit/api/function_impl.h +++ b/torch/csrc/jit/api/function_impl.h @@ -7,7 +7,6 @@ namespace torch::jit { struct TORCH_API GraphFunction : public Function { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GraphFunction( c10::QualifiedName name, std::shared_ptr graph, diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 0dd9c76d5eb6ac..a44eccb601ba32 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -564,7 +564,6 @@ std::string Module::dump_to_str( std::stringstream parameters_ss; std::stringstream attributes_ss; std::stringstream methods_ss; - std::stringstream submodules_ss; for (const NameTensor& p : named_parameters(/*recurse=*/false)) { parameters_ss << p.name << " = "; diff --git a/torch/csrc/jit/backends/backend_detail.cpp b/torch/csrc/jit/backends/backend_detail.cpp index f6cb7b768edcec..de352f50ab503b 100644 --- a/torch/csrc/jit/backends/backend_detail.cpp +++ b/torch/csrc/jit/backends/backend_detail.cpp @@ -11,9 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace detail { +namespace torch::jit::detail { namespace { /* @@ -361,7 +359,7 @@ Module codegen_backend_module( wrapper_method_te.v("def_inputs", def_inputs); wrapper_method_te.v("fwd_inputs", fwd_inputs); - wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te)); + wrapper_methods.emplace_back(wrapper_method_ct.format(wrapper_method_te)); // If the output type is a single element tuple then add an extra comma // to ensure the final output maintains this type. @@ -408,6 +406,4 @@ Module codegen_backend_module( return wrapper; } -} // namespace detail -} // namespace jit -} // namespace torch +} // namespace torch::jit::detail diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp index ba4a2b25c23a78..a5a331d15c21c7 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // Implementation of Android NNAPI Backend delegate @@ -107,7 +106,7 @@ class NnapiBackend : public PyTorchBackendInterface { // Runs once per model initialization // Cannot be moved to compile(), because init() requires actual inputs - void init(c10::IValue handle, c10::List inputs) { + void init(const c10::IValue& handle, const c10::List& inputs) { TORCH_CHECK(comp_ == nullptr); auto dict = handle.toGenericDict(); @@ -134,5 +133,4 @@ constexpr auto backend_name = "nnapi"; static auto cls = torch::jit::backend(backend_name); } // namespace -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/arg_spec.h b/torch/csrc/jit/codegen/fuser/arg_spec.h index 06a72c231b84a5..7239e0391b8fce 100644 --- a/torch/csrc/jit/codegen/fuser/arg_spec.h +++ b/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Describes the (runtime) arguments to a kernel. // ArgSpecs are also used as keys to lookup instantiated kernels, so @@ -55,6 +53,4 @@ struct TORCH_API ArgSpec { int device_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/codegen.cpp b/torch/csrc/jit/codegen/fuser/codegen.cpp index 940444f4ce7edb..04700b905f41bd 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.cpp +++ b/torch/csrc/jit/codegen/fuser/codegen.cpp @@ -18,9 +18,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Template for computing the offset into the tensor to access a value static auto dim_calc = at::jit::CodeTemplate(R"( @@ -538,7 +536,7 @@ std::string generateKernel( // places where the constant None node is used // Note: No need to iterate over reference as n is a pointer for (const auto n : graph.nodes()) { - static_assert(std::is_pointer::value, "n must be a pointer"); + static_assert(std::is_pointer_v, "n must be a pointer"); // Note: FusedConcat nodes work by narrowing the output Tensors before the // kernel runs if (n->kind() == prim::FusedConcat) @@ -680,11 +678,9 @@ std::string generateKernel( } if (debugFuser()) { - std::cerr << "fusion code:" << code_string << std::endl; + std::cerr << "fusion code:" << code_string << '\n'; } return code_string; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/codegen.h b/torch/csrc/jit/codegen/fuser/codegen.h index e42adc1314320e..cab5ccf0eb5e6f 100644 --- a/torch/csrc/jit/codegen/fuser/codegen.h +++ b/torch/csrc/jit/codegen/fuser/codegen.h @@ -9,9 +9,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Creates a CPU or CUDA kernel for the given graph. // Returns the C++ or CUDA string implementing the kernel. @@ -23,6 +21,4 @@ TORCH_API std::string generateKernel( const std::vector>& outputs, const bool use_cuda); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 7e03b576d12184..3944e6f2df82f3 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -30,9 +30,7 @@ std::mutex& fusionBackendLock() { } } // namespace -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { static std::unordered_map& getFusionBackends() { @@ -203,7 +201,7 @@ std::shared_ptr compileKernel( const KernelSpec& spec, const ArgSpec& arg_spec, const std::vector& map_size, - const at::Device device) { + const at::Device& device) { const std::vector& input_desc = arg_spec.descs(); auto graph = spec.graph()->copy(); @@ -297,6 +295,4 @@ std::shared_ptr compileKernel( spec.hasRandom()); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/compiler.h b/torch/csrc/jit/codegen/fuser/compiler.h index 3d490f6940ab0d..d9c0005b9d2858 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.h +++ b/torch/csrc/jit/codegen/fuser/compiler.h @@ -11,9 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Performs device-independent "upfront" compilation of the given fusion_group, // if it has not been registered already. @@ -27,7 +25,7 @@ TORCH_API std::shared_ptr compileKernel( const KernelSpec& spec, const ArgSpec& arg_spec, const std::vector& map_size, - const at::Device device); + const at::Device& device); TORCH_API size_t nCompiledKernels(); @@ -55,6 +53,4 @@ struct TORCH_API RegisterFusionBackend { } }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp index db9d57a679cb15..09624309d16cd1 100644 --- a/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp @@ -11,10 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cpu { +namespace torch::jit::fuser::cpu { #ifdef _MSC_VER static const std::string getTempPath() { @@ -357,7 +354,4 @@ static std::shared_ptr createFusionKernel( } RegisterFusionBackend reg(DeviceType::CPU, createFusionKernel); -} // namespace cpu -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cpu diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index d17da0cbaa21d4..fd9fca230b9668 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -16,13 +16,9 @@ #include #include #include -#include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { // See NOTE [ USE OF NVRTC AND DRIVER API ] const at::cuda::NVRTC& nvrtc() { @@ -85,7 +81,6 @@ void codegenOutputQuery( } // Compiles the specified kernel and stores the metadata required to run it -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FusedKernelCUDA::FusedKernelCUDA( at::DeviceIndex device, std::string name, @@ -114,15 +109,14 @@ FusedKernelCUDA::FusedKernelCUDA( // Acquires device and NVRTC properties (for compile arch and occupancy // calculations) + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) prop_ = at::cuda::getCurrentDeviceProperties(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int major, minor; + int major = 0, minor = 0; bool compile_to_sass = false; codegenOutputQuery(prop_, major, minor, compile_to_sass); // Creates the NVRTC program - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - nvrtcProgram program; + nvrtcProgram program{}; AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); @@ -144,17 +138,14 @@ FusedKernelCUDA::FusedKernelCUDA( "compute_" + #endif std::to_string(major) + std::to_string(minor); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) const std::vector args = { "--std=c++17", compute.c_str(), "-default-device"}; #endif const auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); if (result != NVRTC_SUCCESS) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t logsize; + size_t logsize = 0; AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize)); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector log(logsize); AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data())); std::stringstream cu; @@ -164,8 +155,7 @@ FusedKernelCUDA::FusedKernelCUDA( ResourceGuard holdProgram( [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); AT_CUDA_NVRTC_CHECK(result); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t ptx_size; + size_t ptx_size = 0; #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 // compile_to_sass determines whether we are generating SASS or PTX, hence // the different API. @@ -203,8 +193,7 @@ static int ceilDiv(const int a, const int b) { void FusedKernelCUDA::launch_raw( const uint32_t numel, std::vector& arguments) const { - // NOLINTNEXTLINE(bugprone-unused-raii) - at::cuda::CUDAGuard{device_}; + at::cuda::CUDAGuard guard{device_}; // Hacked at::DeviceGuard (see note above) const auto prior_device = at::cuda::current_device(); at::cuda::set_device(device_); @@ -275,7 +264,4 @@ static std::shared_ptr createFusionKernel( RegisterFusionBackend reg(DeviceType::CUDA, createFusionKernel); -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h index 4d39660532e358..d635049e758a2c 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -12,10 +11,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { // query codegen output arch and target TORCH_CUDA_CU_API void codegenOutputQuery( @@ -53,14 +49,11 @@ struct TORCH_CUDA_CU_API FusedKernelCUDA // Note: per device to store device properties and compute launch heuristics // Acquiring these values at launch time would be too slow at::DeviceIndex device_; - int maxBlocks_; - cudaDeviceProp* prop_; + int maxBlocks_{}; + cudaDeviceProp* prop_{}; std::vector ptx_; - CUmodule module_; - CUfunction function_; + CUmodule module_{}; + CUfunction function_{}; }; -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index e6114f818e3318..ff2ef1f2377ce1 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -3,10 +3,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { +namespace torch::jit::fuser::cuda { /*with type_as not checking type of its input, a fusion group can have non-fp32 tensor as input. Correct code for this case is generated, however, nvrtc does @@ -405,7 +402,4 @@ __device__ float __bfloat162float(const __nv_bfloat16 a) { )"; #endif -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser::cuda diff --git a/torch/csrc/jit/codegen/fuser/executor.cpp b/torch/csrc/jit/codegen/fuser/executor.cpp index 411dbe62a2e157..fa104b7cc16bff 100644 --- a/torch/csrc/jit/codegen/fuser/executor.cpp +++ b/torch/csrc/jit/codegen/fuser/executor.cpp @@ -14,15 +14,9 @@ #include #include -#include // TODO: remove, debugging only -#include -#include -#include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Returns the "map size" for this run, which is the common size for all // intermediate tensors. @@ -215,8 +209,7 @@ static void launchFusion( // Computes map_size, numel from the first input at::IntArrayRef map_size; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t numel; + uint32_t numel = 0; std::vector keep_alive_size; if (fusion.chunkDesc()[0].isNoop()) { map_size = inputs[0].sizes(); @@ -409,6 +402,4 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) { return true; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/executor.h b/torch/csrc/jit/codegen/fuser/executor.h index f454e3d6a1845d..188a25ed8cc65c 100644 --- a/torch/csrc/jit/codegen/fuser/executor.h +++ b/torch/csrc/jit/codegen/fuser/executor.h @@ -7,9 +7,7 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Runs the fusion associated with the key (see registerFusion() in interface.h) // on the inputs taken from the given Stack. @@ -18,6 +16,4 @@ TORCH_API bool runFusion( Stack& stack, std::string* code_out = nullptr); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fallback.cpp b/torch/csrc/jit/codegen/fuser/fallback.cpp index 60a5d72f3c439c..70d90bd94c0ae1 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.cpp +++ b/torch/csrc/jit/codegen/fuser/fallback.cpp @@ -9,9 +9,7 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { namespace { c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() { @@ -46,6 +44,4 @@ void runFallback(int64_t key, Stack& stack) { InterpreterState{(*maybe_spec)->code()}.run(stack); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fallback.h b/torch/csrc/jit/codegen/fuser/fallback.h index 570348ec53640d..af0ff32641ce65 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.h +++ b/torch/csrc/jit/codegen/fuser/fallback.h @@ -4,12 +4,8 @@ #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { void runFallback(int64_t key, Stack& stack); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/fused_kernel.h b/torch/csrc/jit/codegen/fuser/fused_kernel.h index 29ab3e7ed51c06..de00904a749c8e 100644 --- a/torch/csrc/jit/codegen/fuser/fused_kernel.h +++ b/torch/csrc/jit/codegen/fuser/fused_kernel.h @@ -9,14 +9,11 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { struct FusedKernel { AT_DISALLOW_COPY_AND_ASSIGN(FusedKernel); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FusedKernel( std::string name, std::string code, @@ -98,6 +95,4 @@ struct FusedKernel { const bool has_random_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/interface.cpp b/torch/csrc/jit/codegen/fuser/interface.cpp index 13f4fd57142f45..0d6f7c1facc345 100644 --- a/torch/csrc/jit/codegen/fuser/interface.cpp +++ b/torch/csrc/jit/codegen/fuser/interface.cpp @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace detail { @@ -105,5 +104,4 @@ size_t nCompiledKernels() { return fuser::nCompiledKernels(); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/interface.h b/torch/csrc/jit/codegen/fuser/interface.h index 2c67064fe1359c..977e90191160ce 100644 --- a/torch/csrc/jit/codegen/fuser/interface.h +++ b/torch/csrc/jit/codegen/fuser/interface.h @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { constexpr int kCPUDevice = -1; @@ -52,5 +51,4 @@ TORCH_API std::string debugGetFusedKernelCode( TORCH_API size_t nCompiledKernels(); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/codegen/fuser/kernel_cache.cpp b/torch/csrc/jit/codegen/fuser/kernel_cache.cpp index 349378af05ea4d..4d82e4db3bf609 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/fuser/kernel_cache.cpp @@ -6,9 +6,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { struct KernelCacheImpl { // Note: std::unordered_map does not invalidate references even if rehashing @@ -76,7 +74,7 @@ std::optional retrieve(const int64_t key) { } // precondition: graph has been normalized via normalizeGraphForCache -std::optional lookupGraph(std::shared_ptr graph) { +std::optional lookupGraph(const std::shared_ptr& graph) { auto& cache = getKernelCache(); std::string repr = graph->toString(false); @@ -87,6 +85,4 @@ std::optional lookupGraph(std::shared_ptr graph) { return nolock_retrieve(cache, it->second); } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/kernel_cache.h b/torch/csrc/jit/codegen/fuser/kernel_cache.h index c4f340225deb7b..370f782453b1cb 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_cache.h +++ b/torch/csrc/jit/codegen/fuser/kernel_cache.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // A thread-safe cache interface. @@ -22,7 +20,8 @@ TORCH_API std::shared_ptr normalizeGraphForCache( TORCH_API int64_t store(std::shared_ptr graph); // Given a graph, find a KernelSpec based on it -TORCH_API std::optional lookupGraph(std::shared_ptr graph); +TORCH_API std::optional lookupGraph( + const std::shared_ptr& graph); // Returns the graph corresponding to the given key (if it exists) TORCH_API std::optional retrieve(const int64_t key); @@ -31,6 +30,4 @@ TORCH_API std::optional retrieve(const int64_t key); // Only used for testing. TORCH_API int64_t debugNumCachedKernelSpecs(); -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index eacdbc7ec3f336..6b1af19a1a6a61 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -16,9 +16,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Helper struct containing partition information: the number of tensors // created and the dimension the partitioning is performed on. @@ -56,20 +54,19 @@ struct TORCH_API KernelSpec { // Note: assumes the spec is a single block // Note: This is the appropriate place to generalize if you want to add other // passes to upfront compilation that walk the graph. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) KernelSpec(const int64_t _key, const std::shared_ptr& _graph) : key_{_key}, graph_{_graph}, code_{_graph, ""}, nInputs_{_graph->inputs().size()}, - nTensorInputs_{}, + inputBroadcastGroups_{}, inputChunks_{}, - has_random_{false}, + kernels_{} { // No need to iterate over reference since n is pointer for (const auto n : graph_->nodes()) { - static_assert(std::is_pointer::value, "n must be a pointer"); + static_assert(std::is_pointer_v, "n must be a pointer"); if (n->kind() == aten::rand_like) { has_random_ = true; break; @@ -125,8 +122,9 @@ struct TORCH_API KernelSpec { return std::nullopt; return it->second; } - void cacheKernel(const ArgSpec& arg_spec, std::shared_ptr kernel) - const { + void cacheKernel( + const ArgSpec& arg_spec, + const std::shared_ptr& kernel) const { std::lock_guard guard{mutex_}; kernels_.emplace(arg_spec, kernel); } @@ -136,16 +134,14 @@ struct TORCH_API KernelSpec { std::shared_ptr graph_; Code code_; uint64_t nInputs_; - uint64_t nTensorInputs_; + uint64_t nTensorInputs_{}; std::vector> inputBroadcastGroups_; std::vector inputChunks_; - bool has_random_; + bool has_random_{false}; mutable std::mutex mutex_; mutable std:: unordered_map, c10::hash> kernels_; }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/partition_desc.h b/torch/csrc/jit/codegen/fuser/partition_desc.h index 87cf7986e5e91c..964e1821364a50 100644 --- a/torch/csrc/jit/codegen/fuser/partition_desc.h +++ b/torch/csrc/jit/codegen/fuser/partition_desc.h @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Descriptor for chunk-ing an input tensor into subtensors // OR concat-ing an output tensor from subtensors @@ -22,7 +20,6 @@ struct TORCH_API PartitionDesc { PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim) : nSubTensors_{_nSubTensors}, dim_{_dim} { AT_ASSERT(nSubTensors_ > 1); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector cont = _desc.contiguity; if (dim_ > 0) { // when we narrow the concatenated output/chunked input @@ -31,8 +28,7 @@ struct TORCH_API PartitionDesc { // so dim - 1 is no longer contiguous cont[dim_ - 1] = false; } - // NOLINTNEXTLINE(modernize-make-shared) - subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont)); + subTensorDesc_ = std::make_shared(_desc.scalar_type, cont); } bool isNoop() const { @@ -59,6 +55,4 @@ struct TORCH_API PartitionDesc { subTensorDesc_; // descriptor for the subtensor, if it exists }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/tensor_desc.h b/torch/csrc/jit/codegen/fuser/tensor_desc.h index ffc405244a71ef..0c5db65d54ad1a 100644 --- a/torch/csrc/jit/codegen/fuser/tensor_desc.h +++ b/torch/csrc/jit/codegen/fuser/tensor_desc.h @@ -10,20 +10,15 @@ #include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // type information needed by the compiler for input/outputs // contiguity[i] is true if the dim i is contiguous with dim i + 1. // contiguity.back() == true means strides.back() == 1. struct TORCH_API TensorDesc { - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) at::ScalarType scalar_type; - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::vector contiguity; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc(const at::ScalarType& type, const std::vector& contiguity) : scalar_type{type}, contiguity{contiguity} { if (contiguity.empty()) { @@ -35,7 +30,6 @@ struct TORCH_API TensorDesc { } // Delegating constructors - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc( const at::ScalarType& type, const at::IntArrayRef& sizes, @@ -45,7 +39,6 @@ struct TORCH_API TensorDesc { TensorDesc(const at::Tensor& t) : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {} - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDesc(const c10::TensorTypePtr& type) : TensorDesc( type->scalarType().value(), @@ -66,7 +59,6 @@ struct TORCH_API TensorDesc { const at::IntArrayRef& sizes, const at::IntArrayRef& strides) { AT_ASSERT(sizes.size() == strides.size()); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector cont(sizes.size()); for (size_t i = 0; i < sizes.size(); ++i) { const auto expected_stride = @@ -103,6 +95,4 @@ inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) { return out; } -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/codegen/fuser/tensor_info.h b/torch/csrc/jit/codegen/fuser/tensor_info.h index 9f3fcca1aca2e5..77a0d8bacdf23b 100644 --- a/torch/csrc/jit/codegen/fuser/tensor_info.h +++ b/torch/csrc/jit/codegen/fuser/tensor_info.h @@ -1,11 +1,10 @@ #pragma once #include +#include #include -namespace torch { -namespace jit { -namespace fuser { +namespace torch::jit::fuser { // Host-side view of TensorInfo // Note dims[0] - we need to dynamically allocate the dims. @@ -18,12 +17,8 @@ struct TORCH_API TensorInfo { } void* data; -#pragma GCC diagnostic ignored "-Wpedantic" // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) uint32_t sizes_strides[0]; -#pragma GCC diagnostic pop }; -} // namespace fuser -} // namespace jit -} // namespace torch +} // namespace torch::jit::fuser diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h index eade3b23d7266c..a6a4d2e31522c1 100644 --- a/torch/csrc/jit/cuda/cuda.h +++ b/torch/csrc/jit/cuda/cuda.h @@ -12,7 +12,6 @@ class CUDAEvent; // c10/cuda/CUDAStream.h. class CUDAStream final : public CustomClassHolder { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream( std::optional device = std::nullopt, int64_t priority = 0) { @@ -22,7 +21,6 @@ class CUDAStream final : public CustomClassHolder { c10::cuda::getStreamFromPool(static_cast(priority), device_index)); } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAStream(c10::cuda::CUDAStream s) { stream_ = std::make_unique(s); } @@ -69,12 +67,10 @@ class CUDAStream final : public CustomClassHolder { // aten/src/ATen/cuda/CUDAEvent.h. class CUDAEvent final : public CustomClassHolder { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CUDAEvent( bool enable_timing = false, bool blocking = false, bool interprocess = false) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int flags = cudaEventDisableTiming; if (enable_timing) { flags = cudaEventDefault; @@ -95,8 +91,7 @@ class CUDAEvent final : public CustomClassHolder { } std::string ipcHandle() { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaIpcEventHandle_t handle; + cudaIpcEventHandle_t handle{}; event_->ipc_handle(&handle); std::string str_handle((const char*)&handle, sizeof(handle)); return str_handle; diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f9f317d9a43068..4d7904b8707c27 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -38,9 +38,10 @@ TypePtr ScriptTypeParser::subscriptToType( // here. See https://docs.python.org/3/library/typing.html#typing.Tuple auto tup_literal = TupleLiteral(subscript.subscript_exprs()[0]); if (!tup_literal.inputs().empty()) { - throw ErrorReport(tup_literal.range()) + throw( + ErrorReport(tup_literal.range()) << "Tuple literal in Tuple type annotation must not " - << "have any elements!"; + << "have any elements!"); } return TupleType::create({}); } @@ -179,12 +180,12 @@ std::optional> ScriptTypeParser::parseBroadcastList( TypePtr list_ptr = ListType::create(elem_ptr->second); const char* len_c = len.c_str(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* end; + char* end = nullptr; size_t len_v = strtoull(len_c, &end, 10); if (end != len_c + len.size()) { - throw ErrorReport(subscript.subscript_exprs().range()) - << "subscript of Broadcastable list must be a positive integer"; + throw( + ErrorReport(subscript.subscript_exprs().range()) + << "subscript of Broadcastable list must be a positive integer"); } return std::pair(list_ptr, len_v); } diff --git a/torch/csrc/jit/frontend/tree_views.h b/torch/csrc/jit/frontend/tree_views.h index b4547eb85c0c5e..9f3a926fe5a908 100644 --- a/torch/csrc/jit/frontend/tree_views.h +++ b/torch/csrc/jit/frontend/tree_views.h @@ -920,13 +920,11 @@ struct Const : public Expr { double asFloatingPoint() const { // We can't pass in nullptr as the dummy pointer gets dereferenced for // Android version of strtod_c(). - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* dummy; + char* dummy = nullptr; return torch::jit::strtod_c(subtree(0)->stringValue().c_str(), &dummy); } c10::complex asComplex() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - char* dummy; + char* dummy = nullptr; auto str = subtree(0)->stringValue(); // Complex numbers (a+bj, where a is non-zero) are parsed as an addition // between float/int a and a complex number "bj". When a is 0, a complex @@ -1264,7 +1262,11 @@ inline Expr pep604union_to_union(const Expr& expr) { expr.range(), Var::create(expr.range(), Ident::create(expr.range(), "Union")), List::create(expr.range(), members)); +#if defined(__clang__) return std::move(synthesised_union); +#else + return synthesised_union; +#endif } } // namespace torch::jit diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 7b45697b0aa6d9..87af6a1a634386 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -1757,12 +1757,9 @@ bool AliasDb::tryMove( // dependencies WorkingSet workingSet(toMove, *this); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int direction; + auto direction = kNextDirection; if (toMove->isAfter(movePoint)) { direction = kPrevDirection; - } else { - direction = kNextDirection; } auto curNode = toMove->next_in_graph[direction]; diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 73bbf7c10a6e38..44087074e89168 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1743,8 +1743,8 @@ struct OperatorMap { // TODO: return iterator std::vector getAllKeysAndValues() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector keys_values; + keys_values.reserve(map.size()); for (auto& symbol_mapping : map) { auto& vec = symbol_mapping.second; for (auto& pair : vec) { @@ -1819,8 +1819,8 @@ struct FunctionSchemaMap { // TODO: return iterator std::vector getAllKeysAndValues() const { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector keys_values; + keys_values.reserve(map.size()); for (auto& symbol_mapping : map) { auto& vec = symbol_mapping.second; for (auto& pair : vec) { diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp index ef77bb7b642745..b94fc346f5f188 100644 --- a/torch/csrc/jit/jit_log.cpp +++ b/torch/csrc/jit/jit_log.cpp @@ -31,7 +31,6 @@ class JitLoggingConfig { std::unordered_map files_to_levels; std::ostream* out; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) JitLoggingConfig() : out(&std::cerr) { const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL"); logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level); diff --git a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp index b8b1ca6adc0dc4..60f059cd927847 100644 --- a/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/compatibility/model_compatibility.cpp @@ -284,10 +284,6 @@ std::unordered_set _get_mobile_model_contained_types( std::unordered_set _get_mobile_model_contained_types( const std::vector& bytecode_ivalues) { std::unordered_set contained_types; - // To avoid parsing same type twice, declare $parsed_type_names_records and - // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as - // the hash to record which types are parsed. - std::unordered_set parsed_type_names_records; for (const auto i : c10::irange(1, bytecode_ivalues.size())) { const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); auto type_table_tuple = @@ -299,7 +295,6 @@ std::unordered_set _get_mobile_model_contained_types( // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" std::vector type_name_list; for (const auto& type_definition : type_table) { - std::unordered_set type_tokens; std::string type_name = type_definition.toStringRef(); type_name_list.emplace_back(type_name); } diff --git a/torch/csrc/jit/mobile/import_data.cpp b/torch/csrc/jit/mobile/import_data.cpp index 8f206230957280..1bd34e4a823ae9 100644 --- a/torch/csrc/jit/mobile/import_data.cpp +++ b/torch/csrc/jit/mobile/import_data.cpp @@ -61,8 +61,7 @@ c10::IValue IValueUnpickler::readArchive( std::stringstream picklename; picklename << archive_name << ".pkl"; at::DataPtr pickle_ptr; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t pickle_size; + size_t pickle_size = 0; std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str()); size_t bytes_read = 0; diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 405136795df71d..857cb304291025 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -113,10 +113,8 @@ void layout(Stack& stack) { } void toPrimDType(Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); @@ -141,10 +139,8 @@ void boolTensor(Stack& stack) { } void toList(Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int elem_ty_val; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int dim_val; + int elem_ty_val = 0; + int dim_val = 0; at::Tensor t; pop(stack, elem_ty_val); diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 1aff3eac1e641b..766c0843026456 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -108,7 +108,7 @@ std::optional parseAutocast( TORCH_CHECK( dtype != c10::ScalarType::Undefined, "Autocast has invalid fast_dtype attribute"); - if (device == "cuda") { + if (device == "cuda" || device == "mps") { scope.context.gpu_enabled = enabled.value(); scope.context.gpu_scalar_type = dtype; } else if (device == "cpu") { diff --git a/torch/csrc/jit/passes/create_functional_graphs.cpp b/torch/csrc/jit/passes/create_functional_graphs.cpp index 036f30f916693d..86e9fa13893f61 100644 --- a/torch/csrc/jit/passes/create_functional_graphs.cpp +++ b/torch/csrc/jit/passes/create_functional_graphs.cpp @@ -80,7 +80,6 @@ struct FunctionalGraphSlicer { graph_->createWithSubgraph(prim::FunctionalGraph) ->insertBefore(block->return_node()); auto reverse_iter = block->nodes().reverse(); - std::vector graph_outputs; for (auto it = reverse_iter.begin(); it != reverse_iter.end();) { Node* n = *it++; diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index 0cca0d77f193e8..94610679e98e9d 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -40,7 +40,6 @@ static void flattenTupleInLoopParams(Node* n, size_t index) { Block* block = n->blocks().at(0); Node* block_node = n; - std::vector new_node_inputs = {}; auto new_construct_node = block->prependNode(block->owningGraph()->create(prim::TupleConstruct)); for (size_t j = 0; j < tt->elements().size(); ++j) { diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 50e034dd40f0b3..238b8e5c236efe 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -164,7 +164,6 @@ void PreprocessCaffe2Ops(std::shared_ptr& graph) { std::shared_ptr ToONNX( std::shared_ptr& graph, ::torch::onnx::OperatorExportTypes operator_export_type) { - auto constant_value_map = ConstantValueMap::getInstance(); ConstantValueMap::ClearMaps(); auto new_graph = std::make_shared(graph->current_scope()); py::dict env; diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp index de719a1eabc539..5a62e02b628e53 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; } @@ -31,8 +30,7 @@ void CastAllConstantToFloating(Block* block) { auto val_type = TensorType::create(val); if (dtype != at::ScalarType::Double && dtype != at::ScalarType::Float && dtype != at::ScalarType::Half) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int to_type; + int to_type = 0; switch (val.scalar_type()) { case at::ScalarType::Byte: case at::ScalarType::Char: @@ -71,5 +69,4 @@ void CastAllConstantToFloating(Block* block) { void CastAllConstantToFloating(const std::shared_ptr& graph) { CastAllConstantToFloating(graph->block()); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h index c9c40f2dd9a293..72321f0c27d706 100644 --- a/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h +++ b/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h @@ -4,9 +4,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // see .cpp for docs TORCH_API void CastAllConstantToFloating(const std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 61d97057c5b429..75f1ae4c349aaa 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -707,5 +706,4 @@ void ConstantFoldONNX( GRAPH_DUMP("After ConstantFoldONNX:", g); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index d25ebee32a787e..899ae706ca8a2e 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { const int ONNX_OPSET_9 = 9; const int ONNX_OPSET_10 = 10; @@ -30,6 +29,4 @@ void ConstantFoldONNX( std::map& paramDict, int opset_version); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 99c801dcf77367..8b9a6273ef619d 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -7,12 +7,7 @@ #include #include -namespace torch { -namespace jit { - -namespace onnx { -using namespace ::c10::onnx; -} +namespace torch::jit { // Meyer’s Singleton for C++ 14 ConstantValueMap& ConstantValueMap::getInstance() { @@ -290,7 +285,7 @@ void ConstantValueMap::ClearMaps() { // For debug only. void ConstantValueMap::PrintMaps() { - std::cout << "Rank/Shape Map:" << std::endl; + std::cout << "Rank/Shape Map:" << '\n'; for (const auto& x : ConstantValueMap::getInstance().rankMap) { std::stringstream ss; if (ConstantValueMap::getInstance().shapeMap.find(x.first) != @@ -308,45 +303,45 @@ void ConstantValueMap::PrintMaps() { } } ss << " (rank = " << x.second << ")"; - std::cout << "node " << x.first << ": " << ss.str() << std::endl; + std::cout << "node " << x.first << ": " << ss.str() << '\n'; } - std::cout << std::endl; - std::cout << "Value Map:" << std::endl; + std::cout << '\n'; + std::cout << "Value Map:" << '\n'; for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) { - std::cout << "node " << x.first << ": " << x.second << std::endl; + std::cout << "node " << x.first << ": " << x.second << '\n'; } - std::cout << std::endl; - std::cout << "TypeReliable Map:" << std::endl; + std::cout << '\n'; + std::cout << "TypeReliable Map:" << '\n'; size_t count = 0; for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "UseInferredType Map:" << std::endl; + std::cout << '\n'; + std::cout << "UseInferredType Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "ShapeValue Map:" << std::endl; + std::cout << '\n'; + std::cout << "ShapeValue Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().shapeValueMap) { std::cout << "(node " << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "InferredShape Map:" << std::endl; + std::cout << '\n'; + std::cout << "InferredShape Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().inferredShapeData) { std::cout << "(node " << x.first << ": "; @@ -360,29 +355,28 @@ void ConstantValueMap::PrintMaps() { std::cout << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << std::endl; - std::cout << "SymbolDim Map:" << std::endl; + std::cout << '\n'; + std::cout << "SymbolDim Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().symbolDimMap) { std::cout << "(" << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } - std::cout << "DimSymbol Map:" << std::endl; + std::cout << "DimSymbol Map:" << '\n'; count = 0; for (const auto& x : ConstantValueMap::getInstance().dimSymbolMap) { std::cout << "(" << x.first << ": " << x.second << "), "; count++; if (count % 10 == 0) { - std::cout << std::endl; + std::cout << '\n'; } } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index 9bea729a37ad54..60d4470c1b12cf 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -13,8 +13,7 @@ C10_DIAGNOSTIC_POP() #include #include -namespace torch { -namespace jit { +namespace torch::jit { using ShapeDataMap = std::unordered_map; @@ -112,8 +111,7 @@ class ConstantValueMap { // Stores if all graph-level inputs have static shape std::optional allGraphInputsStatic; // True if reliable has been computed for all graph inputs - bool allGraphInputsReliableComputed; + bool allGraphInputsReliableComputed{}; }; -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp b/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp index cbb97c238e7618..009bb157737405 100644 --- a/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp +++ b/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp @@ -4,8 +4,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -99,5 +98,4 @@ void DeduplicateInitializers( buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/deduplicate_initializers.h b/torch/csrc/jit/passes/onnx/deduplicate_initializers.h index 60fbd8f089d4f8..f6da1601110994 100644 --- a/torch/csrc/jit/passes/onnx/deduplicate_initializers.h +++ b/torch/csrc/jit/passes/onnx/deduplicate_initializers.h @@ -4,14 +4,11 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void DeduplicateInitializers( std::shared_ptr& g, std::map& paramsDict, bool is_train); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp b/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp index b803ae35cf3ca5..50e5e4363528e0 100644 --- a/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp +++ b/torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -16,5 +15,4 @@ void EliminateUnusedItemsONNX(Block* b, ParamMap& paramsDict) { return; } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eliminate_unused_items.h b/torch/csrc/jit/passes/onnx/eliminate_unused_items.h index 74e658070edbd2..793a8c1041ff69 100644 --- a/torch/csrc/jit/passes/onnx/eliminate_unused_items.h +++ b/torch/csrc/jit/passes/onnx/eliminate_unused_items.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // EliminateUnusedItemsONNX pass is removing unused // initializers and inputs, this is needed because @@ -12,6 +11,4 @@ void EliminateUnusedItemsONNX( Block* b, std::map& paramDict); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.cpp b/torch/csrc/jit/passes/onnx/eval_peephole.cpp index d6be7039f53188..d25a215ecc8af6 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.cpp +++ b/torch/csrc/jit/passes/onnx/eval_peephole.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -152,5 +151,4 @@ void EvalPeepholeONNX(std::shared_ptr& g, ParamMap& paramsDict) { GRAPH_DUMP("After EvalPeepholeONNX:", g); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/eval_peephole.h b/torch/csrc/jit/passes/onnx/eval_peephole.h index 6f8961d08fd5eb..1bd0bd4c24d225 100644 --- a/torch/csrc/jit/passes/onnx/eval_peephole.h +++ b/torch/csrc/jit/passes/onnx/eval_peephole.h @@ -4,13 +4,10 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void EvalPeepholeONNX( std::shared_ptr& g, std::map& paramDict); -} // namespace jit - -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp index 5407d9aa2986b1..75c44aead7caf3 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp @@ -8,19 +8,14 @@ #include #include -namespace torch { -namespace jit { - -namespace onnx { -using namespace ::c10::onnx; -} +namespace torch::jit { namespace { const int ONNX_OPSET_13 = 13; const int ONNX_TYPE_BOOL = 9; Node* CreateCastToBoolNode(Value* val, Graph* graph) { - Node* cast_node = graph->create(onnx::Cast); + Node* cast_node = graph->create(c10::onnx::Cast); cast_node->addInput(val); cast_node->i_(attr::to, ONNX_TYPE_BOOL); cast_node->output()->setType(BoolType::get()); @@ -149,7 +144,7 @@ std::vector ConvertSequenceDependencies(Node* node, int opset_version) { // Split the added scan_output back to expected tensor sequence. auto loop_output = loop_node->output(i - 2); Node* split_node = - loop_node->owningGraph()->create(onnx::SplitToSequence); + loop_node->owningGraph()->create(c10::onnx::SplitToSequence); loop_output->replaceAllUsesWith(split_node->output()); split_node->i_(attr::keepdims, 0); split_node->addInput(loop_output); @@ -191,7 +186,7 @@ std::vector ConvertSequenceDependencies(Node* node, int opset_version) { return new_outputs; } -Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) { +Node* ONNXOptionalNode(const OptionalTypePtr& opt_type, Graph* g) { TORCH_INTERNAL_ASSERT(opt_type); TypePtr elem_type = opt_type->getElementType(); Node* opt_node = g->create(::c10::onnx::Optional, 1); @@ -208,7 +203,7 @@ Node* ONNXOptionalNode(OptionalTypePtr opt_type, Graph* g) { // 2. Loop Op: insert Optional node before output, if input is Optional type // or output type is None. void ReplaceBlockOutputWithOptional( - OptionalTypePtr opt_type, + const OptionalTypePtr& opt_type, Block* block, size_t i) { Node* opt_node = ONNXOptionalNode(opt_type, block->owningGraph()); @@ -235,9 +230,9 @@ void FixupONNXSubblockOutputs(Node* n) { // Identity(None). Also enables shape inference later on, since // ONNX shape inference doesn't handle None. if (output->type()->cast()) { - id_node = block->owningGraph()->create(onnx::Optional); + id_node = block->owningGraph()->create(c10::onnx::Optional); } else { - id_node = block->owningGraph()->create(onnx::Identity); + id_node = block->owningGraph()->create(c10::onnx::Identity); id_node->addInput(output); } id_node->insertBefore(block->return_node()); @@ -741,5 +736,4 @@ void FixupONNXControlflowNodeOutputs(Node* n) { } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h index 8d33c2dd1fb5e4..f7fc049de99814 100644 --- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h @@ -2,11 +2,9 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { std::vector FixupONNXControlflowNode(Node* n, int opset_version); void FixupONNXControlflowNodeOutputs(Node* n); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index febf412e5d1224..c988b9243669cc 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -2,9 +2,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace { @@ -75,9 +73,9 @@ struct FunctionExtractor { using FunctionCtxPtr = FunctionContext*; using func_ctx_map = std::unordered_map; - static bool IsValidScope(ScopePtr s); + static bool IsValidScope(const ScopePtr& s); static std::optional InferScope(Node* n); - static bool IsAncestor(ScopePtr parent, ScopePtr child); + static bool IsAncestor(const ScopePtr& parent, ScopePtr child); static std::optional FindCommonAncestor(ScopePtr a, ScopePtr b); static std::optional FindCommonAncestor(const scope_list& scopes); std::shared_ptr ConstructFuncGraph(FunctionContext& ctx); @@ -88,7 +86,9 @@ struct FunctionExtractor { scope_ctx_map& scope_ctxs, const std::shared_ptr& graph); - static void HandleNoScopeNodes(scope_ctx_map&, node_list no_scope_nlist); + static void HandleNoScopeNodes( + scope_ctx_map&, + const node_list& no_scope_nlist); std::tuple PartitionNodesByScope(Block* b); scope_ctx_map PartitionNodesByScope(const std::shared_ptr& graph); static std::unordered_map PartitionIdenticalScopes( @@ -279,11 +279,11 @@ void FunctionExtractor::DebugPrintGraphWithFunction( GRAPH_UPDATE("Main graph: ", g->toString()); } -bool FunctionExtractor::IsValidScope(ScopePtr s) { +bool FunctionExtractor::IsValidScope(const ScopePtr& s) { return !s->isRoot() && !s->isBlank(); } -bool FunctionExtractor::IsAncestor(ScopePtr parent, ScopePtr child) { +bool FunctionExtractor::IsAncestor(const ScopePtr& parent, ScopePtr child) { if (!IsValidScope(parent) || !IsValidScope(child) || parent->getDepth() >= child->getDepth()) { return false; @@ -376,7 +376,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { std::all_of( output_scopes.begin(), output_scopes.end(), - [&output_scopes](ScopePtr scope) -> bool { + [&output_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == output_scopes.at(0); })) { return output_scopes.at(0); @@ -385,7 +385,7 @@ std::optional FunctionExtractor::InferScope(Node* n) { std::all_of( input_scopes.begin(), input_scopes.end(), - [&input_scopes](ScopePtr scope) -> bool { + [&input_scopes](const ScopePtr& scope) -> bool { return IsValidScope(scope) && scope == input_scopes.at(0); })) { return input_scopes.at(0); @@ -822,7 +822,7 @@ void FunctionExtractor::ScopeContext::PopulateInputsOutputs( void FunctionExtractor::HandleNoScopeNodes( scope_ctx_map& scope_ctxs, - node_list no_scope_nlist) { + const node_list& no_scope_nlist) { GRAPH_UPDATE("No scope node count: ", no_scope_nlist.size()); for (auto n : no_scope_nlist) { TORCH_WARN( @@ -1181,6 +1181,4 @@ void ONNXTrackScopeAttributes( } } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/function_extraction.h b/torch/csrc/jit/passes/onnx/function_extraction.h index 40555f8e3561ca..fea0d23e703010 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.h +++ b/torch/csrc/jit/passes/onnx/function_extraction.h @@ -2,9 +2,6 @@ #include -namespace torch { -namespace jit { - // This api will be used by serialization/export.cpp to extract function // information. It should do conversion on graph to // 1. Extract subgraph pattern of functions and define as local function @@ -15,7 +12,7 @@ namespace jit { // represent these info inside Graph object. // export.cpp will serialize the ONNX model with function_proto with // above information. -namespace onnx { +namespace torch::jit::onnx { // The following return types are used to track information regarding function // attributes, that are unable to be traced through Torch IR. @@ -64,7 +61,4 @@ TORCH_API void ONNXTrackScopeAttributes( std::shared_ptr& graph, std::map& attributes); -} // namespace onnx - -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/function_substitution.cpp b/torch/csrc/jit/passes/onnx/function_substitution.cpp index 81bfa3fd6caf59..5a8e24b015da35 100644 --- a/torch/csrc/jit/passes/onnx/function_substitution.cpp +++ b/torch/csrc/jit/passes/onnx/function_substitution.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace { @@ -193,5 +192,4 @@ void ONNXFunctionCallSubstitution(Graph& graph) { GRAPH_DUMP("After function call substitution calls: ", &graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/function_substitution.h b/torch/csrc/jit/passes/onnx/function_substitution.h index b07acef2c3fc66..3571bab936e2c4 100644 --- a/torch/csrc/jit/passes/onnx/function_substitution.h +++ b/torch/csrc/jit/passes/onnx/function_substitution.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void ONNXFunctionCallSubstitution(Graph& graph); } -} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 9d4c5061414c5a..64a4bb9bdfab36 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -12,8 +12,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -296,5 +295,4 @@ void ONNXLintGraph(const std::shared_ptr& graph) { " constants."); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/helper.h b/torch/csrc/jit/passes/onnx/helper.h index 9e09c638779ef5..09b31576998a5d 100644 --- a/torch/csrc/jit/passes/onnx/helper.h +++ b/torch/csrc/jit/passes/onnx/helper.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // Utility functions for PyTorch to ONNX conversion. @@ -73,5 +72,4 @@ class ScalarTypeHashFunction { } }; -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 6a1e3b08f3b9a8..268cc4dc7f9cbb 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -191,5 +190,4 @@ std::pair> list_module_parameters( return std::make_pair(moduleClone, parameterIValues); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.h b/torch/csrc/jit/passes/onnx/list_model_parameters.h index 50d1cea2b8fe00..114b3b2d894139 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.h +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.h @@ -3,11 +3,9 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API std::pair> list_module_parameters( const Module& module); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/naming.cpp b/torch/csrc/jit/passes/onnx/naming.cpp index ed62c79231899d..62fd67e1d2ca47 100644 --- a/torch/csrc/jit/passes/onnx/naming.cpp +++ b/torch/csrc/jit/passes/onnx/naming.cpp @@ -3,9 +3,7 @@ #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace ONNXScopeName { @@ -16,7 +14,7 @@ const std::string name_separator = "::"; namespace { std::string nameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator, NameFunc name_func) { std::string out = (*name_func)(scope); @@ -32,7 +30,7 @@ std::string nameFromRoot( } std::pair parseNameFromScope( - torch::jit::ScopePtr scope) { + const torch::jit::ScopePtr& scope) { std::string full_name = scope->name().toUnqualString(); auto pos = full_name.find(name_separator); TORCH_CHECK( @@ -55,7 +53,7 @@ std::string variableName(torch::jit::ScopePtr scope) { } std::string variableNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator) { return nameFromRoot(scope, layer_separator, &variableName); } @@ -65,12 +63,12 @@ std::string className(torch::jit::ScopePtr scope) { } std::string classNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator) { return nameFromRoot(scope, layer_separator, &className); } -bool isCompatibleScope(torch::jit::ScopePtr scope) { +bool isCompatibleScope(const torch::jit::ScopePtr& scope) { return !scope->isRoot() && !scope->isBlank() && (std::string(scope->name().toUnqualString()).find(name_separator) != std::string::npos); @@ -89,7 +87,7 @@ class NodeNameGenerator { virtual void CreateNodeName(Node* n) = 0; void PopulateNodeNames(Block*); void UpdateOutputsNames(Node* n); - bool IsGraphOutput(const Value* v, const std::shared_ptr graph) const; + bool IsGraphOutput(const Value* v, const std::shared_ptr& graph) const; protected: std::string CreateUniqueName( @@ -99,19 +97,21 @@ class NodeNameGenerator { std::unordered_map node_names_; std::unordered_map base_node_name_counts_; std::shared_ptr graph_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string layer_separator_ = "/"; }; NodeNameGenerator::~NodeNameGenerator() = default; class ScopedNodeNameGenerator : public NodeNameGenerator { public: - ScopedNodeNameGenerator(std::shared_ptr g) : NodeNameGenerator(g){}; + ScopedNodeNameGenerator(std::shared_ptr g) + : NodeNameGenerator(std::move(g)){}; protected: void CreateNodeName(Node* n) override; private: - std::string GetFullScopeName(ScopePtr scope); + std::string GetFullScopeName(const ScopePtr& scope); std::unordered_map full_scope_names_; std::unordered_map base_scope_name_counts_; }; @@ -131,7 +131,7 @@ std::string NodeNameGenerator::CreateUniqueName( bool NodeNameGenerator::IsGraphOutput( const Value* v, - const std::shared_ptr graph) const { + const std::shared_ptr& graph) const { for (const auto* graph_output : graph->outputs()) { if (v == graph_output) { return true; @@ -185,7 +185,7 @@ void ScopedNodeNameGenerator::CreateNodeName(Node* n) { n->s_(Symbol::attr(::torch::onnx::kOnnxNodeNameAttribute), node_names_[n]); } -std::string ScopedNodeNameGenerator::GetFullScopeName(ScopePtr scope) { +std::string ScopedNodeNameGenerator::GetFullScopeName(const ScopePtr& scope) { if (full_scope_names_.find(scope) == full_scope_names_.end()) { auto full_scope_name = ONNXScopeName::variableNameFromRoot(scope, layer_separator_); @@ -202,6 +202,4 @@ void AssignScopedNamesForNodeAndValue(std::shared_ptr& graph) { node_name_generator->PopulateNodeNames(); } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/naming.h b/torch/csrc/jit/passes/onnx/naming.h index 0472d041a92331..905d47bf541b4e 100644 --- a/torch/csrc/jit/passes/onnx/naming.h +++ b/torch/csrc/jit/passes/onnx/naming.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace ONNXScopeName { @@ -13,18 +11,16 @@ std::string createFullScopeName( const std::string& variable_name); std::string variableName(torch::jit::ScopePtr scope); std::string variableNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator); std::string className(torch::jit::ScopePtr scope); std::string classNameFromRoot( - torch::jit::ScopePtr scope, + const torch::jit::ScopePtr& scope, const std::string& layer_separator); -bool isCompatibleScope(torch::jit::ScopePtr scope); +bool isCompatibleScope(const torch::jit::ScopePtr& scope); } // namespace ONNXScopeName TORCH_API void AssignScopedNamesForNodeAndValue(std::shared_ptr& graph); -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/onnx_log.cpp b/torch/csrc/jit/passes/onnx/onnx_log.cpp index eff2ae4d5a3231..f749690145601f 100644 --- a/torch/csrc/jit/passes/onnx/onnx_log.cpp +++ b/torch/csrc/jit/passes/onnx/onnx_log.cpp @@ -1,9 +1,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { namespace { bool log_enabled = false; @@ -26,6 +24,4 @@ std::ostream& _get_log_output_stream() { return out ? *out : std::cout; } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/onnx_log.h b/torch/csrc/jit/passes/onnx/onnx_log.h index a659a122342760..b3343df4c6e387 100644 --- a/torch/csrc/jit/passes/onnx/onnx_log.h +++ b/torch/csrc/jit/passes/onnx/onnx_log.h @@ -4,9 +4,7 @@ #include #include -namespace torch { -namespace jit { -namespace onnx { +namespace torch::jit::onnx { TORCH_API bool is_log_enabled(); @@ -22,6 +20,4 @@ TORCH_API std::ostream& _get_log_output_stream(); << ::c10::str(__VA_ARGS__) << std::endl; \ } -} // namespace onnx -} // namespace jit -} // namespace torch +} // namespace torch::jit::onnx diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index 18c31ea656610d..407a8c79dfb7fe 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -23,8 +23,7 @@ typedef SSIZE_T ssize_t; #endif -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -818,8 +817,7 @@ static void fuseLogSoftmaxNllLoss(Block* b) { if (it->kind() == onnx::NegativeLogLikelihoodLoss) { auto prev = it->input(0)->node(); Node* origNllLossNode = *it; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* origLogSoftmaxNode; + Node* origLogSoftmaxNode = nullptr; // Check for patterns especially in cases with autocasting enabled // in which a cast node is inserted before the NegativeLogLikelihoodLoss @@ -1069,5 +1067,4 @@ void PeepholeOptimizeONNX( GRAPH_DUMP("After PeepholeOptimizeONNX", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/peephole.h b/torch/csrc/jit/passes/onnx/peephole.h index 7d23267310ab85..3a3819974d5625 100644 --- a/torch/csrc/jit/passes/onnx/peephole.h +++ b/torch/csrc/jit/passes/onnx/peephole.h @@ -2,13 +2,11 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void PeepholeOptimizeONNX( std::shared_ptr& graph, int opset_version, bool fixed_batch_size); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp index 63ad0040237e25..f7c0e0a7cac0c0 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { // onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0, // so before converting the ints to tensors we need to cast them to floats. @@ -43,5 +42,4 @@ void PrepareDivisionForONNX(const std::shared_ptr& graph) { GRAPH_DUMP("After PrepareDivisionForONNX: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h index 1bfd8eaef311d2..b9e25861b4778c 100644 --- a/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // Prepare division ops for ONNX export. This is necessary for and only used // by ONNX export. @@ -15,5 +14,4 @@ namespace jit { // TORCH_API void PrepareDivisionForONNX(const std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 89c3a42a9c0814..5f35a85b2aa891 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -229,5 +228,4 @@ void PreprocessForONNX(std::shared_ptr& graph) { GRAPH_DUMP("After fuseListAndListUnpack: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h index dc180a59c1b6b7..541e4339768e23 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { void PreprocessForONNX(std::shared_ptr& graph); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp index 8e0258a3049e02..4fa3068c3d1e46 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp @@ -13,8 +13,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { namespace { @@ -368,7 +367,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) { } } -static void PrepareForRemoveMutations(std::shared_ptr graph) { +static void PrepareForRemoveMutations(const std::shared_ptr& graph) { MutationRemover mr(graph); PrepareForRemoveMutations(mr, graph->block()); GRAPH_DUMP("After PrepareForRemoveMutations: ", graph); @@ -438,23 +437,23 @@ std::string InplaceConverter::ValueTracker::toString() const { // ss << "Current graph: " << graph_->toString() << std::endl; ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values." - << std::endl; - ss << "value_to_sorted_aliases_: " << std::endl; + << '\n'; + ss << "value_to_sorted_aliases_: " << '\n'; size_t idx = 0; for (const auto& it : value_to_sorted_aliases_) { - ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl; + ss << "Value[" << idx << "]: " << it.first->debugName() << '\n'; ss << " Mapping to "; for (auto v : it.second) { ss << v->debugName() << " "; } - ss << std::endl; + ss << '\n'; idx++; } - ss << "alias_to_value_: " << std::endl; + ss << "alias_to_value_: " << '\n'; for (auto it : alias_to_value_) { ss << " Alias " << it.first->debugName(); - ss << " map to " << it.second->debugName() << std::endl; + ss << " map to " << it.second->debugName() << '\n'; } return ss.str(); @@ -890,5 +889,4 @@ void RemoveInplaceOpsForONNX( ic.convertMutationForONNX(); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h index 6afd28268bdeb9..996c7f80a6c1f6 100644 --- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h +++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h @@ -2,12 +2,10 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void RemoveInplaceOpsForONNX( const std::shared_ptr& graph, Module* model); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 009566499275b4..5369dedac9762b 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -479,5 +478,4 @@ void ScalarTypeAnalysisNodeForONNX(Node* n) { ImplicitCastNodeForONNX(n); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h index 90a433a3ce47ed..8c5051220e84f6 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void ScalarTypeAnalysisForONNX( const std::shared_ptr& graph, @@ -11,5 +10,4 @@ TORCH_API void ScalarTypeAnalysisForONNX( int opset_version); void ScalarTypeAnalysisNodeForONNX(Node* n); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 294d0682246e9f..9a6dcd9bd761dd 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -22,16 +22,15 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { inline bool PyNone_Check(PyObject* o) { return o == Py_None; } std::pair MergeInferredType( - TypePtr existing_type, - TypePtr inferred_type) { + const TypePtr& existing_type, + const TypePtr& inferred_type) { auto new_list_type = inferred_type->cast(); auto use_inferred_type = false; if (new_list_type) { @@ -75,8 +74,8 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, - TypePtr existing_type, - TypePtr inferred_type) { + const TypePtr& existing_type, + const TypePtr& inferred_type) { auto [mergedType, inferred] = MergeInferredType(existing_type, inferred_type); dest_v->setType(mergedType); ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); @@ -256,7 +255,7 @@ bool CustomSettype(Node* node) { Value* CloneValueFromListConstruct( Value* v, - std::shared_ptr n_graph, + const std::shared_ptr& n_graph, int opset_version) { auto lc_node = v->node(); TORCH_INTERNAL_ASSERT(lc_node->kind() == ::c10::prim::ListConstruct); @@ -355,7 +354,7 @@ Node* CloneNodeToGraph( return clone_node; } -bool HasValidType(TypePtr type, std::string name) { +bool HasValidType(const TypePtr& type, const std::string& name) { if (auto t_type = type->cast()) { if (!t_type->scalarType().has_value()) { GRAPH_UPDATE("Input ", name, " is missing tensor datatype."); @@ -371,7 +370,7 @@ bool HasValidType(TypePtr type, std::string name) { return true; } -bool IsGraphValidForInference(std::shared_ptr graph) { +bool IsGraphValidForInference(const std::shared_ptr& graph) { // Verify if every input has type (either Tensor, Sequence or Optional) and // scalar type. This is a requirement for ONNX graph inputs. for (auto in : graph->inputs()) { @@ -381,7 +380,7 @@ bool IsGraphValidForInference(std::shared_ptr graph) { } void ConvertGraphToONNXProto( - std::shared_ptr graph, + const std::shared_ptr& graph, std::shared_ptr& model_proto, SymbolDimMap& symbol_dim_map, DimSymbolMap& dim_symbol_map, @@ -1652,7 +1651,8 @@ void SpecialPostProcess(Node* n) { auto seq_node = n->input(0)->node(); auto t_type = n->input(1)->type()->cast(); - auto update_sequence_empty_dtype = [](Node* n, TensorTypePtr t_type) { + auto update_sequence_empty_dtype = [](Node* n, + const TensorTypePtr& t_type) { TORCH_INTERNAL_ASSERT(n && n->kind() == ::c10::onnx::SequenceEmpty); TORCH_INTERNAL_ASSERT(t_type && t_type->scalarType().has_value()); auto scalar_type = t_type->scalarType().value(); @@ -1711,7 +1711,7 @@ void SpecialPostProcess(Node* n) { return nullptr; }; return find_sequence_empty_impl( - input, t_type, find_sequence_empty_impl); + input, std::move(t_type), find_sequence_empty_impl); }; if (seq_node && t_type && t_type->scalarType()) { @@ -2122,9 +2122,9 @@ void ONNXShapeTypeInference( case ::c10::onnx::Gather: { auto* schema_registry = onnx::OpSchemaRegistry::Instance(); onnx::ShapeInferenceOptions options{ - /*check_type=*/false, - /*error_mode=*/false, - /*enable_data_propagation=*/true}; + /*check_type_val=*/false, + /*strict_mode_val=*/0, + /*data_prop_val=*/true}; onnx::shape_inference::InferShapes( *model_proto, schema_registry, options, &inferred_shape_data); break; @@ -2509,5 +2509,4 @@ void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) { } } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index 685ca39c16dec1..bca534654febb8 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { // Merges existing_type and inferred_type. // Returns {merged type, whether or not inferred_type was used}. @@ -28,13 +27,13 @@ namespace jit { // ONNX represents list of scalars by 1-d Tensor. Return inferred type since // it is more compatible with ONNX. std::pair MergeInferredType( - TypePtr existing_type, - TypePtr inferred_type); + const TypePtr& existing_type, + const TypePtr& inferred_type); void MergeInferredTypeAndSetMap( Value* dest_v, - TypePtr existing_type, - TypePtr inferred_type); + const TypePtr& existing_type, + const TypePtr& inferred_type); // Update graph input types with dynamic axes info. // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. @@ -96,5 +95,4 @@ void UpdateReliable( void UpdateReliable(torch::jit::Node* n); void UpdateShapeConstantIfReliable(torch::jit::Value* output); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 850070aa0b2fa3..7a4f95ec69763a 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -14,8 +14,8 @@ #include using ::c10::Dispatcher; -namespace torch { -namespace jit { + +namespace torch::jit { namespace onnx { using namespace ::c10::onnx; @@ -299,10 +299,6 @@ void unpackQuantizedWeightsHelper( torch::List stride_int, padding_int, dilation_int, output_padding_int; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t groups_int; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t transpose_int; if (itr->second.isTuple()) { // Pre-unpacked weights. Comes from Conv/Linear weights which are @@ -415,9 +411,9 @@ void unpackQuantizedWeightsHelper( } idx++; } - groups_int = conv_params_packed[idx].item(); + auto groups_int = conv_params_packed[idx].item(); idx++; - transpose_int = conv_params_packed[idx].item(); + auto transpose_int = conv_params_packed[idx].item(); idx++; TORCH_INTERNAL_ASSERT( idx == conv_params_packed.numel(), @@ -459,11 +455,10 @@ void unpackQuantizedWeightsHelper( for (const auto& d : dilation_ivalue) { dilation_int.emplace_back(d.toTensor()[0].item()); } - groups_int = groups_ivalue.toTensor()[0].item(); + groups = groups_ivalue.toTensor()[0].item(); stride = stride_int; padding = padding_int; dilation = dilation_int; - groups = groups_int; if (expect_output_padding) { auto output_padding_ivalue = @@ -770,5 +765,4 @@ void insertPermutes( GRAPH_DUMP("After insertPermutes: ", graph); } -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h index aa9b7af63478c0..70a99b4ef1859b 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace jit { +namespace torch::jit { TORCH_API void UnpackQuantizedWeights( std::shared_ptr& graph, @@ -15,5 +14,4 @@ TORCH_API void UnpackQuantizedWeights( TORCH_API void insertPermutes( std::shared_ptr& graph, std::map& paramsDict); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index f906efacceca7b..9aacd481a55b04 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -1120,8 +1120,7 @@ void InsertObserversHelper::fillBoundaryValueMap( // offset of input for the caller node, since the first // input of CallFunction is the function node and the graph // for CallFunction start with actual input - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t input_offset; + size_t input_offset = 0; if (n->kind() == prim::CallMethod) { auto m_opt = getInvokedModuleOpt(module, n, self); if (!m_opt.has_value()) { @@ -1469,8 +1468,7 @@ InsertObserversHelper::insertObserversFor( if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) { script::Module m; std::shared_ptr g; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t input_offset; + size_t input_offset = 0; bool is_udf_for_subblock = is_user_defined_function; if (n->kind() == prim::CallMethod) { auto m_opt = getInvokedModuleOpt(module, n, self); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 3e0b265ae70c8f..05c19bdb38a1fd 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -513,7 +513,6 @@ void ReplicateChooseQParamsQuantDequant(std::shared_ptr& graph) { Node* pattern_choose_qparam = choose_qparam_val->node(); std::vector nodes_to_rewrite; - std::vector choose_qparam_nodes_to_rewrite; for (const Match& match : matches) { Node* matched_dequantize = match.nodes_map.at(pattern_dequant); Node* matched_quantize = match.nodes_map.at(pattern_quant); @@ -1557,7 +1556,6 @@ QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams( "getQSchemeAndParamMap expects the corresponding observer for ", v->debugName(), " exists."); - std::vector qparams_graph_values; QuantOpParams quant_op_params; TORCH_CHECK( diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index 851548862dfc47..aeba208ea98e9c 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -808,13 +808,6 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype) "%count_include_pad", "%divisor_override"}); - std::string common_general_value_op = R"( - %r_scale : float = aten::q_scale(%a_quant) - %r_zero_point : int = aten::q_zero_point(%a_quant) - %r_dtype : int = prim::dtype(%a_quant) - %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype) - return (%r_quant) )"; - auto avg_pool3d = getInputTensorQParamOpFusionInfo( "aten::avg_pool3d", {"%kernel_size", diff --git a/torch/csrc/jit/passes/quantization/register_packed_params.cpp b/torch/csrc/jit/passes/quantization/register_packed_params.cpp index 1d7dcfe72eea56..c3696cdc5109cf 100644 --- a/torch/csrc/jit/passes/quantization/register_packed_params.cpp +++ b/torch/csrc/jit/passes/quantization/register_packed_params.cpp @@ -59,7 +59,6 @@ std::unordered_set RegisterPrePackParams( int64_t uid = 0; // int + method name gives unique identifier auto graph = m.get_method(method_name).graph(); std::stack blocks_to_visit; - std::unordered_set nodes_to_delete; blocks_to_visit.push(graph->block()); std::string attr_name_base = attr_prefix + "_" + method_name + "_ondevice_ptq_packed_weight_"; diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index f4dfc4ce99c940..0cc07a18c05eba 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -133,7 +133,6 @@ void mergeSubgraph(Node* mergeTo, Node* mergeFrom) { } ++it; - std::vector merged_nodes; while (it != end_it) { Node* node = *it; ++it; diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 555e7f50c3dd4e..eee1cf05b1201c 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -31,10 +31,6 @@ #include #include -#ifdef USE_C10D_NCCL -#include -#include -#endif #include #include #include @@ -1077,7 +1073,11 @@ inline py::object createPyObjectForStack(Stack&& stack) { return_values[ret] = toPyObject(std::move(stack[ret])); } +#if defined(__clang__) return std::move(return_values); +#else + return return_values; +#endif } // TODO: Remove once we clean up the GraphExecutor usage. diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 7a815e815d8e96..324fc37e080c69 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -36,7 +36,6 @@ struct ArgumentInfo { return requires_grad_; } int dim() const { - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return dim_; } at::ScalarType type() const { @@ -104,7 +103,6 @@ struct ArgumentSpec { if (arg.defined_) { arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad(); arg.dim_ = t->dim(); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) at::Device device = t->device(); arg.dev_type_ = // NOLINTNEXTLINE(bugprone-signed-char-misuse) @@ -117,8 +115,7 @@ struct ArgumentSpec { } void combineHash(const ArgumentInfo& arg) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ArgumentInfo::plain_data_type arg_data; + ArgumentInfo::plain_data_type arg_data = 0; std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo)); hash_code = c10::hash_combine(hash_code, arg_data); } @@ -242,12 +239,10 @@ static_assert( struct CompleteArgumentInfo; struct CompleteArgumentSpec { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CompleteArgumentSpec(bool with_grad, at::ArrayRef inputs) : hash_code(0), ninputs(inputs.size()) { int32_t all_dims = 0; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - const int32_t num_inputs = inputs.size(); + const auto num_inputs = inputs.size(); for (const auto i : c10::irange(num_inputs)) { if (!inputs[i].isTensor()) continue; @@ -258,7 +253,6 @@ struct CompleteArgumentSpec { data.resize(ninputs + all_dims * 2); // and reinterpret our data array as these structs - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) auto* pods = reinterpret_cast(data.data()); int64_t* next_dim = sizes_strides(); int32_t total_dims = 0; @@ -270,7 +264,6 @@ struct CompleteArgumentSpec { pod.defined = t.defined(); if (pod.defined) { pod.type = static_cast(t.scalar_type()); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) at::Device device = t.device(); // NOLINTNEXTLINE(bugprone-signed-char-misuse) pod.dev_type = static_cast::type>( @@ -394,7 +387,6 @@ struct CompleteArgumentInfo { int sizes_strides_offset(int j) const { if (j == 0) return 0; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return 2 * pod(j - 1).total_dims; } const CompleteArgumentInfoPOD& pod(int j) const { diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 047a35e417fff8..d525fcaecd816b 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -206,8 +206,7 @@ class GradientHelper { } else if (node->kind() == prim::ConstantChunk) { auto* g = node->owningGraph(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* input_list; + Value* input_list = nullptr; if (grad_values.size() == 1 && grad_values[0]->type()->isSubtypeOf(*ListType::ofTensors())) { input_list = grad_values[0]; @@ -575,8 +574,7 @@ static void foldSizeIfNotEqual(Node* node) { // insert in front of _grad_sum_to_size WithInsertPoint guard(node); IValue ival{}; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* size; + Value* size = nullptr; if (input_size != output_size) { size = node->owningGraph()->insertConstant(*input_size); } else { diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 9762106340b404..8517e6a94b57b6 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -340,8 +340,7 @@ struct CodeImpl { int reg = registerFor(input); bool moved = input->uses().size() == ++use_count_[input]; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - OpCode op; + OpCode op{}; if (input->node()->kind() == prim::Constant) { op = LOADC; } else if (moved) { diff --git a/torch/csrc/jit/runtime/jit_trace.cpp b/torch/csrc/jit/runtime/jit_trace.cpp index cff9d5f954e813..b25088b32ecae2 100644 --- a/torch/csrc/jit/runtime/jit_trace.cpp +++ b/torch/csrc/jit/runtime/jit_trace.cpp @@ -56,8 +56,8 @@ Node* traceNode(Node* node, TracingData& td, Stack& stack) { } void eraseAllOutputs(Node* opt_pn) { - // NOLINTNEXTLINE - for (int i = opt_pn->outputs().size() - 1; i >= 0; i--) { + for (auto i = static_cast(opt_pn->outputs().size()) - 1; i >= 0; + i--) { opt_pn->eraseOutput(i); } } @@ -275,10 +275,12 @@ void insertTracingNodes(Block* block, ProfilingRecord* pr, TracingData& td) { // nodes and the outputs of the node in the scripted graph. // There are a few subtleties with tracing Ifs and Loops // discussed above -std::shared_ptr TraceGraph(std::shared_ptr graph, Stack& stack) { +std::shared_ptr TraceGraph( + const std::shared_ptr& graph, + Stack& stack) { TracingData td; GRAPH_DUMP("Before Inline:", graph); - Inline(*graph.get()); + Inline(*graph); EliminateDeadCode(graph); GRAPH_DUMP("After Inline:", graph); auto pr = ProfilingRecord::instrumentGraph(graph); diff --git a/torch/csrc/jit/runtime/jit_trace.h b/torch/csrc/jit/runtime/jit_trace.h index 12be844e35a91a..9b29501eeb3f91 100644 --- a/torch/csrc/jit/runtime/jit_trace.h +++ b/torch/csrc/jit/runtime/jit_trace.h @@ -3,6 +3,6 @@ namespace torch::jit { TORCH_API std::shared_ptr TraceGraph( - std::shared_ptr graph, + const std::shared_ptr& graph, Stack& stack); } // namespace torch::jit diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index a057367af81c96..abbdf44ec60518 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -287,12 +287,12 @@ void listAdd(Stack& stack) { c10::List ret = make_result_list(a.elementType()); if (a.use_count() == 1) { - ret = std::move(a); + ret = a; } else { ret = a.copy(); } - ret.append(std::move(b)); + ret.append(b); push(stack, std::move(ret)); } @@ -300,7 +300,7 @@ void listAdd(Stack& stack) { void listInplaceAdd(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); - a.append(std::move(b)); + a.append(b); push(stack, std::move(a)); } diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index f6eccede28bab1..d20c5d6a0fec54 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -13,18 +13,11 @@ #include #include #include -#include -#include #include -#include #include -#include #include #include #include -#include -#include -#include #include #include @@ -259,8 +252,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::__range_length(int lo, int hi, int step) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t lo, hi, step; + int64_t lo = 0, hi = 0, step = 0; pop(stack, lo, hi, step); // error handling when step_val = 0 during runtime if (step == 0) { @@ -279,8 +271,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::__derive_index(int index, int start, int step) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t index, start, step; + int64_t index = 0, start = 0, step = 0; pop(stack, index, start, step); push(stack, start + index * step); }, @@ -336,8 +327,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; + int64_t i = 0; pop(stack, i); push(stack, (bool)i); }, @@ -345,8 +335,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Bool.float(float a) -> bool"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; + double d = 0; pop(stack, d); push(stack, (bool)d); }, @@ -362,8 +351,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Int.bool(bool a) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; + bool b = false; pop(stack, b); push(stack, static_cast(b)); }, @@ -371,8 +359,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Int.float(float a) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d; + double d = 0; pop(stack, d); push(stack, static_cast(d)); }, @@ -394,8 +381,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("aten::Int.str(str a) -> int"), [](Stack& stack) { auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; + std::string::size_type sz = 0; int64_t val = static_cast(std::stoll(s->string(), &sz)); if (sz == s->string().size()) { push(stack, val); @@ -432,8 +418,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Float.int(int a) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t i; + int64_t i = 0; pop(stack, i); push(stack, (float)i); }, @@ -441,8 +426,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::Float.bool(bool a) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool b; + bool b = false; pop(stack, b); push(stack, (float)b); }, @@ -451,8 +435,7 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA("aten::Float.str(str a) -> float"), [](Stack& stack) { auto s = pop(stack).toString(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - std::string::size_type sz; + std::string::size_type sz = 0; double b = std::stod(s->string(), &sz); if (sz == s->string().size()) { push(stack, b); @@ -847,7 +830,7 @@ static const std::vector opGenArgs{ ss << i; } drop(stack, num_inputs); - ss << std::endl; + ss << '\n'; auto* handler = getPrintHandler(); TORCH_INTERNAL_ASSERT(handler); handler(ss.str()); @@ -1036,8 +1019,7 @@ static const std::vector opGenArgs{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::pow.int_to_int(int a, int b) -> int"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t a, b; + int64_t a = 0, b = 0; pop(stack, a, b); push(stack, powWrapper(a, b)); }, @@ -1270,10 +1252,8 @@ static const std::vector opGenArgs{ TORCH_SELECTIVE_SCHEMA( "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, non_blocking, copy); std::optional scalarType = pop(stack).toOptional(); @@ -1709,14 +1689,12 @@ static const std::vector dict_ops{ }; RegisterOperators reg_dict_ops(createOperators(dict_ops)); -// NOLINTNEXTLINE(clang-diagnostic-unused-function) constexpr c10::AliasAnalysisKind aliasAnalysisFromSchema() { return c10::AliasAnalysisKind::FROM_SCHEMA; } // Convert an python index (which may be negative) into an index usable for a // C++ container -// NOLINTNEXTLINE(clang-diagnostic-unused-function) int64_t normalizeIndex(int64_t idx, int64_t list_size) { if (idx < 0) { // Handle negative indexing @@ -2418,8 +2396,7 @@ static const std::vector opGenArgs1{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("prim::rangelist(int n) -> int[]"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t n; + int64_t n = 0; pop(stack, n); c10::List elems; elems.reserve(n); @@ -2458,10 +2435,8 @@ static const std::vector opGenArgs1{ "aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack& stack) { at::Tensor self; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool non_blocking; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool copy; + bool non_blocking = false; + bool copy = false; pop(stack, self, non_blocking, copy); std::optional device = std::nullopt; std::optional scalarType = std::nullopt; @@ -3077,25 +3052,20 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::modf(float a) -> (float, float)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; + double a = 0; pop(stack, a); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double b, c; - b = modf(a, &c); + double c = 0; + double b = modf(a, &c); push(stack, b, c); }, aliasAnalysisFromSchema()), OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::frexp(float a) -> (float, int)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; + double a = 0; pop(stack, a); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double m; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int e; + double m = 0; + int e = 0; m = std::frexp(a, &e); push(stack, m, e); }, @@ -3103,10 +3073,8 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::ldexp(float x, int i) -> float"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t b; + double a = 0; + int64_t b = 0; pop(stack, a, b); push(stack, std::ldexp(a, b)); }, @@ -3388,8 +3356,7 @@ static const std::vector opGenArgs2{ OperatorGeneratorArgs( TORCH_SELECTIVE_SCHEMA("aten::divmod.int(int x, int y) -> (int, int)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t a, b; + int64_t a = 0, b = 0; lldiv_t divresult = {}; pop(stack, a, b); if (b == 0) { @@ -3411,8 +3378,7 @@ static const std::vector opGenArgs2{ TORCH_SELECTIVE_SCHEMA( "aten::divmod.float(float x, float y) -> (float, float)"), [](Stack& stack) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a, b; + double a = 0, b = 0; pop(stack, a, b); if (b == 0) { throw std::runtime_error("ZeroDivisionError: float divmod()"); diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 63fdee6de8042c..5fa06b09274519 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -16,7 +16,6 @@ #include #include -#include #include namespace torch::jit { @@ -187,8 +186,7 @@ template void createTensorFromList(Stack& stack) { // torch.tensor has a fourth requires_grad arg but torch.as_tensor not, so // we use the template arg to distinguish between these two cases - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool requires_grad; + bool requires_grad = false; IValue data; IValue dtype; IValue device; @@ -334,10 +332,8 @@ RegisterOperators reg({ [](Stack& stack) { at::Tensor weight; at::Tensor input; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double max_norm; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double norm_type; + double max_norm = 0; + double norm_type = 0; pop(stack, weight, input, max_norm, norm_type); // TODO: remove when script supports setting grad mode @@ -402,13 +398,11 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double a; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double b; std::optional generator = pop(stack).toOptional(); + double a = 0; + double b = 0; pop(stack, tensor, a, b); push(stack, tensor.uniform_(a, b, generator)); }, @@ -421,10 +415,8 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double mean; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double std; + double mean = 0; + double std = 0; std::optional generator = pop(stack).toOptional(); @@ -440,8 +432,7 @@ RegisterOperators reg({ torch::NoGradGuard no_grad; at::Tensor tensor; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double val; + double val = 0; pop(stack, tensor, val); push(stack, at::fill_(tensor, val)); }, diff --git a/torch/csrc/jit/runtime/script_profile.cpp b/torch/csrc/jit/runtime/script_profile.cpp index c31f27223b8b7b..3ad4716d32b598 100644 --- a/torch/csrc/jit/runtime/script_profile.cpp +++ b/torch/csrc/jit/runtime/script_profile.cpp @@ -147,7 +147,7 @@ const ScriptProfile::SourceMap& ScriptProfile::dumpStats() { for (const auto& datapoint : datapoints_) { if (const auto& source = datapoint->sourceRange.source()) { if (auto fileLineCol = datapoint->sourceRange.file_line_col()) { - auto it = sourceMap_.find(*source.get()); + auto it = sourceMap_.find(*source); if (it == sourceMap_.end()) { it = sourceMap_.emplace(SourceRef{source}, LineMap{}).first; } diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp index c87d3d9b069ca2..f52c24e9f01ff6 100644 --- a/torch/csrc/jit/runtime/static/fusion.cpp +++ b/torch/csrc/jit/runtime/static/fusion.cpp @@ -271,8 +271,7 @@ void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) { while (any_changed) { any_changed = false; for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - bool changed; + bool changed = false; std::tie(it, changed) = scanNode(*it, aliasDb); any_changed |= changed; } diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 589691eeae6584..15f22bee7dfc0c 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -399,8 +399,6 @@ ManagedTensorRanges::ManagedTensorRanges( const AliasDb& alias_db, const c10::FastSet& managed_tensor_values) { const std::vector nodes(block.nodes().begin(), block.nodes().end()); - const c10::FastSet graph_inputs( - block.inputs().begin(), block.inputs().end()); const auto num_nodes = static_cast(nodes.size()); for (const auto i : c10::irange(num_nodes)) { @@ -457,8 +455,7 @@ ManagedTensorRanges::ManagedTensorRanges( for (auto* managed_tensor : managed_tensor_values) { auto* lifetime = getLifetime(managed_tensor); DCHECK(lifetime && lifetime->end <= num_nodes); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Node* freeing_node; + Node* freeing_node = nullptr; if (lifetime->end == num_nodes) { freeing_node = block.return_node(); } else { diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp index 92d901e43a5d21..beecdb770cba37 100644 --- a/torch/csrc/jit/runtime/symbolic_script.cpp +++ b/torch/csrc/jit/runtime/symbolic_script.cpp @@ -1556,12 +1556,12 @@ static void loadModule(const CompilationUnit& module) { Node* forward_tuple = pair.forward->outputs().at(0)->node(); if (forward_tuple->kind() != prim::TupleConstruct) { - throw ErrorReport(forward_tuple->sourceRange()) - << "gradient must return literal a tuple"; + throw( + ErrorReport(forward_tuple->sourceRange()) + << "gradient must return literal a tuple"); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Value* context; + Value* context = nullptr; std::tie(pair.backward, context) = extractClosure(forward_tuple->inputs().back()); diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index 837ce21095adeb..74f87e46757eaf 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -219,7 +219,7 @@ void checkInputAndOutputTypes( void transformShapeFunction( const FunctionSchema* schema_string, - std::shared_ptr graph) { + const std::shared_ptr& graph) { Inline(*graph); // ATEN operators can return multiple unboxed values, this in contrast to @@ -411,7 +411,7 @@ TORCH_API std::optional boundedGraphsForSchema( void RegisterShapeComputeGraphForSchema( const FunctionSchema& schema, - std::shared_ptr g) { + const std::shared_ptr& g) { std::lock_guard guard(lock); if (cached_schema_to_graph.empty()) { loadFunctions(); diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.h b/torch/csrc/jit/runtime/symbolic_shape_registry.h index a14d327aab4291..7222fd8bca3269 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.h +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.h @@ -52,7 +52,7 @@ struct BoundedShapeGraphs { TORCH_API void RegisterShapeComputeGraphForSchema( const FunctionSchema& schema, - std::shared_ptr g); + const std::shared_ptr& g); TORCH_API std::optional> shapeComputeGraphForSchema( const FunctionSchema& schema); diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index 07eca82df2df54..e5dbae392ccb47 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -149,7 +149,6 @@ mobile::Code compileGraphToMobileCode( // operator names std::vector method_names; - std::vector op_debug_handles; int next_new_op_index = 0; auto op_to_specified_args = code.op_to_num_specified_args(); diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index f201ce49bd7683..ee83ff78444f04 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -518,7 +518,6 @@ flatbuffers::Offset FlatbufferSerializer:: } else { size_t num_attr = class_ptr->numAttributes(); std::vector> names; - std::vector type_index; for (size_t i = 0; i < num_attr; ++i) { names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i))); } diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index ffe0c8330708e1..9be9b0fb2d8c1d 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -94,7 +94,6 @@ enum class PickleOpCode : char { using ::c10::IValue; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct WriteableTensorData { const char* data() const { return static_cast(tensor_.storage().data()); @@ -140,7 +139,6 @@ class TORCH_API Pickler { memoized_class_types_(memoized_class_types), get_tensor_id_(std::move(get_tensor_id)), tag_aggregates_(tag_aggregates) {} - // NOLINTNEXTLINE(bugprone-exception-escape) ~Pickler(); // Push protocol onto the stack diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index b818051bf14a88..e1a42cb1d45935 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -244,13 +244,12 @@ class RegisterCodeGen { RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance(); codegen_list.AddStmtFactoryMethod( name, - [](StmtPtr stmt, + [](const StmtPtr& stmt, const std::vector& params, at::Device device, const std::string& kernel_func_name) { - std::unique_ptr method( - new CodeGenType(stmt, params, device, kernel_func_name)); - return method; + return std::make_unique( + stmt, params, device, kernel_func_name); }); } }; diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 58876fb6a5e75e..a3d2274a1eccf0 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -1044,7 +1044,7 @@ class SimpleIREvaluatorImpl : public IRVisitor { v->buffer_var()->name_hint()); } buffer_mapping_[b] = buffer->data(); - internal_buffers_.insert(std::make_pair(b, std::move(buffer))); + internal_buffers_.emplace(std::move(b), std::move(buffer)); } void visit(const PlacementAllocatePtr& v) override { diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index a498408a37c3e0..35c6ac03ce8dda 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -552,7 +552,7 @@ bool Buf::is_stride_one(int cur_dim) const { return exprEquals(strides_[cur_dim], alloc(1)); } -ExprHandle expr_to_vec(ExprHandle v, int lanes) { +ExprHandle expr_to_vec(const ExprHandle& v, int lanes) { if (lanes == 1) { return v; } else { diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h index fbef2b134057e3..c5c41cd0a045ce 100644 --- a/torch/csrc/jit/tensorexpr/expr.h +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -488,6 +488,6 @@ TORCH_API ExprHandle Relu(const ExprHandle& v1); TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); -TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes); +TORCH_API ExprHandle expr_to_vec(const ExprHandle& v, int lanes); } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index decfe0bceb3215..5fe00d7959c408 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1365,10 +1365,10 @@ void nnc_aten_addmm( const at::Tensor& y = tensors[2]; const at::Tensor& z = tensors[3]; // TODO: handle other alpha and beta dtypes, e.g. alpha=0.6, beta=0.2 - int64_t alpha = extra_args[0], beta = extra_args[1]; + int64_t beta = extra_args[0], alpha = extra_args[1]; try { - at::addmm_out(r, x, y, z, alpha, beta); + at::addmm_out(r, x, y, z, beta, alpha); } catch (...) { } } diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 76cb0b0daec0c4..cdd2c0e66bf7ab 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -2885,7 +2885,6 @@ ExprPtr SimplifierUnderContext::mutate(const DivPtr& v) { ExprPtr lhs = v->lhs(); ExprPtr rhs = v->rhs(); - std::ostringstream oss; if (auto ret = distributeDiv(lhs, rhs, var_bound_info_)) { GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); return ret->accept_mutator(this); @@ -3005,7 +3004,6 @@ ExprPtr SimplifierUnderContext::mutate(const ModPtr& v) { ExprPtr lhs = v->lhs(); ExprPtr rhs = v->rhs(); - std::ostringstream oss; if (auto ret = distributeMod(lhs, rhs, var_bound_info_)) { GRAPH_DEBUG("SimplifierUnderContext: ", *v, " => ", *ret); return ret->accept_mutator(this); diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 4720da2dc59fc9..fa6aa3d70f766e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -984,7 +984,6 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( // we use the debug names in printing cuda code, they need to be removed // of characters that can't be used in a variable identifier void TensorExprKernel::genInputDebugNames() { - std::unordered_map name_to_value; std::unordered_set name_set; std::unordered_map value_to_name; for (const torch::jit::Value* input : graph_->inputs()) { @@ -1747,7 +1746,6 @@ void TensorExprKernel::compile() { VarPtr v = t.buf()->base_handle(); scalars_[output] = VarHandle(v); block->append_stmt(t.stmt()); - std::vector dims; BufHandle buf( "scalar_" + sanitizeName(output->debugName()), {}, v->dtype()); StmtPtr store = Store::make(buf, {}, ExprHandle(v)); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 334ed836ebe739..a9f24139e029f3 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -626,6 +626,20 @@ void LLVMCodeGenImpl::emitWrapper(const std::vector& params) { module_.get()); #endif + { + // Work around UBSAN crashes which reads 8 byte in front of every function. + // Otherwise, if the function was placed at the beginning of a page, reading + // 8B before the page could trigger a wild-addr-read ASAN failure if the + // page before this function was not mapped. + // - https://reviews.llvm.org/D148665 + // - https://github.com/llvm/llvm-project/issues/65253 + // Place the variable just before the function, + // the optimizer might otherwise disable this workaround. + // https://llvm.org/docs/LangRef.html#prefix-data + wrapper->setPrefixData(llvm::Constant::getNullValue( + llvm::ArrayType::get(llvm::Type::getInt8Ty(getContext()), 8))); + } + auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper); irb_.SetInsertPoint(wrapBB); llvm::SmallVector wrappedArgs; diff --git a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp index cd16de19255803..343bb8e8f30df9 100644 --- a/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest_randomization.cpp @@ -480,7 +480,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) { } int index = rand() % (int)all_nested_loops.size(); - auto nested_loops = all_nested_loops.at(index); + auto const& nested_loops = all_nested_loops.at(index); if (nested_loops.size() < 2) { break; } @@ -554,7 +554,7 @@ void loopnestRandomization(int64_t seed, LoopNest& l) { // Randomly pick a set of consecutive loops to flatten int index = rand() % (int)all_nested_loops.size(); - auto nested_loops = all_nested_loops.at(index); + auto const& nested_loops = all_nested_loops.at(index); // Generate a good history message std::vector indices; diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index bfce006d55177e..2dd46335e09f8e 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace jit { -namespace tensorexpr { +namespace torch::jit::tensorexpr { namespace { @@ -20,8 +18,8 @@ void assert_dims_constant(const BufHandle& buf) { using InitFunc = std::function&)>; Tensor conv2d_depthwise_static( - BufHandle input, - BufHandle weight, + const BufHandle& input, + const BufHandle& weight, const InitFunc& init_func, int stride, int pad, @@ -78,14 +76,12 @@ Tensor conv2d_depthwise_static( constexpr int kLoopH = 2, kLoopW = 3; if (R == 3 && stride == 2 && pad == 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr head, tail; auto loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopW], 2, &head, &tail); loops = nest.getLoopStmtsFor(conv); nest.sliceHead(loops[kLoopH], 2, &head, &tail); } else if (R == 3 && stride == 1 && pad == 1) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) ForPtr main, peeled; auto loops = nest.getAllLoopNestsWritingToBuf(conv.buf()); main = loops[1][kLoopW]; @@ -489,6 +485,4 @@ Tensor computeMkldnnPrepackedConvRun( return Tensor(ResultBuf.node(), s); } -} // namespace tensorexpr -} // namespace jit -} // namespace torch +} // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.cpp b/torch/csrc/jit/tensorexpr/operators/reduction.cpp index 93030a1a8e045c..3e189e7acab9ac 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/operators/reduction.cpp @@ -146,7 +146,6 @@ Tensor computeMax( } BufHandle ResultBuf("max", outputShape, dtype); BufHandle InputBuf = std::get(inputs[0]); - std::vector max_dims_expr; auto max_dim = std::get(inputs[1]); auto keep_dim = std::get(inputs[2]); return Tensor( diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index fc871b73bbb3c4..62bde5c4c33ea1 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -84,51 +84,47 @@ class TORCH_API Block : public StmtNode { return stmts_.empty(); } - void prepend_stmt(StmtPtr s) { + void prepend_stmt(const StmtPtr& s) { if (s->get_parent()) { - throw malformed_input( - "Block prepend Stmt with existing parent", std::move(s)); + throw malformed_input("Block prepend Stmt with existing parent", s); } stmts_.push_front(s); set_parent(s, this); } - void append_stmt(StmtPtr s) { + void append_stmt(const StmtPtr& s) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } stmts_.push_back(s); set_parent(s, this); } - void insert_stmt_before(StmtPtr s, const StmtPtr& before) { + void insert_stmt_before(const StmtPtr& s, const StmtPtr& before) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } auto pos = std::find(stmts_.begin(), stmts_.end(), before); if (pos == stmts_.end()) { throw malformed_input( - "Inserting after statement that is not in block", std::move(s)); + "Inserting after statement that is not in block", s); } stmts_.insert(pos, s); set_parent(s, this); } - void insert_stmt_after(StmtPtr s, const StmtPtr& after) { + void insert_stmt_after(const StmtPtr& s, const StmtPtr& after) { if (s->get_parent()) { - throw malformed_input( - "Block append Stmt with existing parent", std::move(s)); + throw malformed_input("Block append Stmt with existing parent", s); } auto pos = std::find(stmts_.begin(), stmts_.end(), after); if (pos == stmts_.end()) { throw malformed_input( - "Inserting after statement that is not in block", std::move(s)); + "Inserting after statement that is not in block", s); } ++pos; @@ -137,10 +133,10 @@ class TORCH_API Block : public StmtNode { set_parent(s, this); } - bool replace_stmt(const StmtPtr& old_stmt, StmtPtr new_stmt) { + bool replace_stmt(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( - "Block replace Stmt with existing parent", std::move(new_stmt)); + "Block replace Stmt with existing parent", new_stmt); } auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt); @@ -157,10 +153,10 @@ class TORCH_API Block : public StmtNode { // Creates a new block by cloning `this` block and replacing the given // statement with a new statement. Note that `old_stmt` refers to a statement // in `this` block. If the `old_stmt` is not found, it will return `nullptr`. - BlockPtr clone_and_replace(const StmtPtr& old_stmt, StmtPtr new_stmt) { + BlockPtr clone_and_replace(const StmtPtr& old_stmt, const StmtPtr& new_stmt) { if (new_stmt->get_parent()) { throw malformed_input( - "Block replace Stmt with existing parent", std::move(new_stmt)); + "Block replace Stmt with existing parent", new_stmt); } std::vector stmts(stmts_.begin(), stmts_.end()); @@ -260,9 +256,7 @@ class TORCH_API Block : public StmtNode { StmtPtr p1_p = std::move(p1); while (p1_p) { if (BlockPtr b = to(p1_p)) { - if (b) { - enclosing.insert(b); - } + enclosing.insert(b); } p1_p = p1_p->get_parent(); } diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp index f74c7eb37a9df6..8f344e001e31ff 100644 --- a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -31,7 +31,7 @@ const std::string& UniqueNameManager::get_unique_name(const VarPtr& v) { } if (all_unique_names_.count(unique_name) == 0) { all_unique_names_.insert(unique_name); - auto result = unique_name_mapping_.insert(std::make_pair(v, unique_name)); + auto result = unique_name_mapping_.emplace(v, unique_name); return result.first->second; } } diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 027eb2aa0acf6b..e9bf764c31575d 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -228,6 +228,7 @@ struct FileCheckImpl { groups.push_back({check}); } } + checks.push_back(check); has_run = false; } diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index cf2af60a1572df..0de566f3cf10be 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -80,6 +80,12 @@ void initModule(PyObject* module) { at::detail::getMTIAHooks().memoryStats(device_index); return py::reinterpret_steal(raw_pyobject); }); + + m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) { + PyObject* raw_pyobject = + at::detail::getMTIAHooks().getDeviceCapability(device_index); + return py::reinterpret_steal(raw_pyobject); + }); } } // namespace mtia diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 3bfd0a4b8f3d9b..67cfaab3b81957 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -33,11 +33,17 @@ RawTensorMetadataBase::RawTensorMetadataBase(const at::Tensor& t) : data_{t.has_storage() ? t.storage().data() : nullptr}, dtype_{t.scalar_type()}, layout_{t.layout()}, - dim_{static_cast(t.sizes().size())} { + size_dim_{static_cast(t.sizes().size())} { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( t.sizes().size() <= std::numeric_limits::max(), "Cannot profile Tensors of size > uint32 max. Got dim: ", t.sizes().size()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + t.sizes().size() == t.strides().size(), + "Tensor has mismatching sizes and strides. Sizes: ", + t.sizes().size(), + " Strides: ", + t.strides().size()); } RawTensorMetadata::RawTensorMetadata(const at::Tensor& t) @@ -181,14 +187,29 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { ivals_it = ivalues_.begin(), io_type]() mutable { auto decode_tensor = [&]() -> TensorMetadata { - const auto& raw_metadata = *tensor_metadata_it++; std::vector sizes; std::vector strides; - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { + if (tensor_metadata_it.exhausted()) { + LOG(WARNING) + << "Tensor metadata exhausted prematurely. Reported shapes may be inaccurate!"; + return {RawTensorMetadata(), sizes, strides}; + } + const auto& raw_metadata = *tensor_metadata_it++; + for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { + if (tensor_size_strides_it.exhausted()) { + LOG(WARNING) + << "Expected Tensor Size mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; + return {raw_metadata, sizes, strides}; + } sizes.push_back(*tensor_size_strides_it++); } if (raw_metadata.layout_ == at::kStrided) { - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.dim_)) { + for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { + if (tensor_size_strides_it.exhausted()) { + LOG(WARNING) + << "Expected Tensor Strides mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; + return {raw_metadata, sizes, strides}; + } strides.push_back(*tensor_size_strides_it++); } } diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index 716fdb910c01ea..abaa9a845082b6 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -47,7 +47,7 @@ struct TORCH_API RawTensorMetadataBase { StorageImplData data_; c10::ScalarType dtype_{c10::ScalarType::Undefined}; c10::Layout layout_{c10::Layout::Strided}; - uint32_t dim_{0}; + uint32_t size_dim_{0}; }; // Collected during profiling. diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 25f93a2663dfb5..661646920632e2 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -441,7 +441,7 @@ void initPythonBindings(PyObject* module) { return py::reinterpret_borrow( torch::autograd::utils::wrap(metadata.dtype_)); }) - .def_readonly("dim", &TensorMetadata::dim_) + .def_readonly("dim", &TensorMetadata::size_dim_) .def_readonly("sizes", &TensorMetadata::sizes_) .def_readonly("strides", &TensorMetadata::strides_); diff --git a/torch/csrc/profiler/stubs/itt.cpp b/torch/csrc/profiler/stubs/itt.cpp index e131e19adc4f4f..89d7c32657c278 100644 --- a/torch/csrc/profiler/stubs/itt.cpp +++ b/torch/csrc/profiler/stubs/itt.cpp @@ -4,9 +4,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") -namespace torch { -namespace profiler { -namespace impl { +namespace torch::profiler::impl { namespace { struct ITTMethods : public ProfilerStubs { @@ -51,7 +49,5 @@ struct RegisterITTMethods { RegisterITTMethods reg; } // namespace -} // namespace impl -} // namespace profiler -} // namespace torch -C10_CLANG_DIAGNOSTIC_POP() +} // namespace torch::profiler::impl +C10_DIAGNOSTIC_POP() diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 8ff3910b6cd612..22ddf02d8452ea 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -7,29 +7,29 @@ !__has_include("ext/stdio_filebuf.h") namespace torch::unwind { std::vector unwind() { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } std::optional> libraryFor(void* addr) { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } #ifndef FBCODE_CAFFE2 std::vector symbolize(const std::vector& frames, Mode mode) { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } #endif Stats stats() { - TORCH_CHECK( - false, + TORCH_WARN_ONCE( "record_context_cpp is not support on non-linux non-x86_64 platforms"); + return {}; } } // namespace torch::unwind diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 6e582b25c5ebb9..0542309f81cf95 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -293,17 +293,29 @@ std::string strListToStr(const std::vector& types) { return "[" + rc + "]"; } } -std::string ivalueToStr(const c10::IValue& val) { +std::string ivalueToStr(const c10::IValue& val, bool isString) { std::stringstream ss; if (val.isNone()) { return "\"None\""; } else { ss.str(""); - ss << "\""; + if (isString) { + ss << "\""; + } ss << val; - ss << "\""; + if (isString) { + ss << "\""; + } std::string mystr = ss.str(); + // For boolean the values that ivalue gives is "True" and "False" but + // json only takes "true" and "false" so we convert the string to lower case + if (val.isBool()) { + for (char& c : mystr) { + c = std::tolower(c); + } + } + // A double quote can cause issues with the chrome tracing so force // all inputs to not contain more than the 2 we add in this function int count = std::count(mystr.begin(), mystr.end(), '\"'); diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index f68d0e7b1aadb8..ba149e740b221f 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -92,7 +92,7 @@ TORCH_API std::string shapesToStr( TORCH_API std::string strListToStr(const std::vector& types); TORCH_API std::string inputOpIdsToStr( const std::list>& input_op_ids); -TORCH_API std::string ivalueToStr(const c10::IValue& val); +TORCH_API std::string ivalueToStr(const c10::IValue& val, bool isString); TORCH_API std::string ivalueListToStr(const std::vector& list); TORCH_API std::vector inputTypes(const at::RecordFunction& fn); diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index 10ba23a656f08c..c922da900613d9 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -254,7 +254,7 @@ void THPStorage_writeFileRaw( doWrite(fd, &numel, sizeof(int64_t)); else { int64_t nsize{}; // convert big endian cpu to little endian storage - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( (uint8_t*)&nsize, (const int64_t*)&numel, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, @@ -274,19 +274,19 @@ void THPStorage_writeFileRaw( for (size_t i = 0; i < numel; i += buffer_size) { size_t to_convert = std::min(numel - i, buffer_size); if (element_size == 2) { - torch::utils::THP_encodeInt16Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int16_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 4) { - torch::utils::THP_encodeInt32Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int32_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (element_size == 8) { - torch::utils::THP_encodeInt64Buffer( + torch::utils::THP_encodeBuffer( le_buffer.data(), (const int64_t*)data + i, torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, @@ -322,7 +322,7 @@ c10::intrusive_ptr THPStorage_readFileRaw( if (torch::utils::THP_nativeByteOrder() == torch::utils::THPByteOrder::THP_BIG_ENDIAN) { int64_t tsize = size; // convert little endian storage to big endian cpu - torch::utils::THP_decodeInt64Buffer(&size, (const uint8_t*)&tsize, true, 1); + torch::utils::THP_decodeBuffer(&size, (const uint8_t*)&tsize, true, 1); } size_t nbytes = element_size * size; if (!storage.defined()) { @@ -369,13 +369,13 @@ c10::intrusive_ptr THPStorage_readFileRaw( // NOLINTNEXTLINE(bugprone-branch-clone) if (element_size == 2) { - torch::utils::THP_decodeInt16Buffer( + torch::utils::THP_decodeBuffer( (int16_t*)data + i, le_buffer.get(), true, to_convert); } else if (element_size == 4) { - torch::utils::THP_decodeInt32Buffer( + torch::utils::THP_decodeBuffer( (int32_t*)data + i, le_buffer.get(), true, to_convert); } else if (element_size == 8) { - torch::utils::THP_decodeInt64Buffer( + torch::utils::THP_decodeBuffer( (int64_t*)data + i, le_buffer.get(), true, to_convert); } } diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index fda2c45a9c88b3..cbac2367500b3f 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -267,7 +267,14 @@ char* tensor_repr(at::Tensor tensor) { const char* buf = nullptr; char* result = nullptr; - pytensor = THPVariable_Wrap(std::move(tensor)); + // NB: It's important not to move the tensor into THPVariable_Wrap, + // because this function is only called from our gdb macros, and + // we want to avoid accidentally moving out the tensor. In principle, + // the Tensor signature above should induce a copy, but we've + // observed that sometimes gdb passes the outer Tensor address exactly as is + // into this function. + // See https://github.com/pytorch/pytorch/issues/134762 + pytensor = THPVariable_Wrap(tensor); if (!pytensor) // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) goto error; diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index 6c41253fbb8664..4432e74bd06c10 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -117,49 +117,38 @@ THPByteOrder THP_nativeByteOrder() { return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN; } -void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src) - : decodeUInt16(src)); - src += sizeof(int16_t); - } -} - -void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src) - : decodeUInt32(src)); - src += sizeof(int32_t); - } -} +template +void THP_decodeBuffer(T* dst, const uint8_t* src, U type, size_t len) { + if constexpr (std::is_same_v) + THP_decodeBuffer(dst, src, type != THP_nativeByteOrder(), len); + else { + auto func = [&](const uint8_t* src_data) { + if constexpr (std::is_same_v) { + return type ? decodeUInt16ByteSwapped(src_data) + : decodeUInt16(src_data); + } else if constexpr (std::is_same_v) { + return type ? decodeUInt32ByteSwapped(src_data) + : decodeUInt32(src_data); + } else if constexpr (std::is_same_v) { + return type ? decodeUInt64ByteSwapped(src_data) + : decodeUInt64(src_data); + } + }; -void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len) { - for (const auto i : c10::irange(len)) { - dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src) - : decodeUInt64(src)); - src += sizeof(int64_t); + for (const auto i : c10::irange(len)) { + dst[i] = static_cast(func(src)); + src += sizeof(T); + } } } -void THP_decodeHalfBuffer( +template <> +TORCH_API void THP_decodeBuffer( c10::Half* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint16_t x; c10::Half f; @@ -170,7 +159,8 @@ void THP_decodeHalfBuffer( } } -void THP_decodeBFloat16Buffer( +template <> +TORCH_API void THP_decodeBuffer( at::BFloat16* dst, const uint8_t* src, bool do_byte_swap, @@ -183,19 +173,24 @@ void THP_decodeBFloat16Buffer( } } -void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len) { +template <> +TORCH_API void THP_decodeBuffer( + bool* dst, + const uint8_t* src, + bool, + size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int)src[i] != 0 ? true : false; } } -void THP_decodeFloatBuffer( +template <> +TORCH_API void THP_decodeBuffer( float* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float f; @@ -206,13 +201,13 @@ void THP_decodeFloatBuffer( } } -void THP_decodeDoubleBuffer( +template <> +TORCH_API void THP_decodeBuffer( double* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double d; @@ -223,18 +218,17 @@ void THP_decodeDoubleBuffer( } } -void THP_decodeComplexFloatBuffer( +template <> +TORCH_API void THP_decodeBuffer, bool>( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float re; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t y; float im; @@ -249,18 +243,17 @@ void THP_decodeComplexFloatBuffer( } } -void THP_decodeComplexDoubleBuffer( +template <> +TORCH_API void THP_decodeBuffer, bool>( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double re; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t y; double im; @@ -276,150 +269,46 @@ void THP_decodeComplexDoubleBuffer( } } -void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} +#define DEFINE_DECODE(TYPE, ORDER) \ + template TORCH_API void THP_decodeBuffer( \ + TYPE * dst, const uint8_t* src, ORDER type, size_t len); -void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len); -} +DEFINE_DECODE(int16_t, THPByteOrder) +DEFINE_DECODE(int32_t, THPByteOrder) +DEFINE_DECODE(int64_t, THPByteOrder) +DEFINE_DECODE(c10::Half, THPByteOrder) +DEFINE_DECODE(float, THPByteOrder) +DEFINE_DECODE(double, THPByteOrder) +DEFINE_DECODE(c10::BFloat16, THPByteOrder) +DEFINE_DECODE(c10::complex, THPByteOrder) +DEFINE_DECODE(c10::complex, THPByteOrder) -void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} +DEFINE_DECODE(int16_t, bool) +DEFINE_DECODE(int32_t, bool) +DEFINE_DECODE(int64_t, bool) -void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} +#undef DEFINE_DECODE -void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len) { - THP_decodeComplexDoubleBuffer( - dst, src, (order != THP_nativeByteOrder()), len); -} - -void THP_encodeInt16Buffer( - uint8_t* dst, - const int16_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int16_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes16(dst); - dst += sizeof(int16_t); - } - } -} - -void THP_encodeInt32Buffer( - uint8_t* dst, - const int32_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int32_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes32(dst); - dst += sizeof(int32_t); - } - } -} - -void THP_encodeInt64Buffer( - uint8_t* dst, - const int64_t* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(int64_t) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes64(dst); - dst += sizeof(int64_t); - } - } -} - -void THP_encodeFloatBuffer( - uint8_t* dst, - const float* src, - THPByteOrder order, - size_t len) { - memcpy(dst, src, sizeof(float) * len); - if (order != THP_nativeByteOrder()) { - for (const auto i : c10::irange(len)) { - (void)i; - swapBytes32(dst); - dst += sizeof(float); - } - } -} - -void THP_encodeDoubleBuffer( +template +void THP_encodeBuffer( uint8_t* dst, - const double* src, + const T* src, THPByteOrder order, size_t len) { - memcpy(dst, src, sizeof(double) * len); + memcpy(dst, src, sizeof(T) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; - swapBytes64(dst); - dst += sizeof(double); + if constexpr (std::is_same_v) { + swapBytes16(dst); + } else if constexpr ( + std::is_same_v || std::is_same_v) { + swapBytes32(dst); + } else if constexpr ( + std::is_same_v || std::is_same_v) { + swapBytes64(dst); + } + dst += sizeof(T); } } } @@ -436,7 +325,8 @@ std::vector complex_to_float(const c10::complex* src, size_t len) { return new_src; } -void THP_encodeComplexFloatBuffer( +template <> +TORCH_API void THP_encodeBuffer>( uint8_t* dst, const c10::complex* src, THPByteOrder order, @@ -452,7 +342,8 @@ void THP_encodeComplexFloatBuffer( } } -void THP_encodeComplexDoubleBuffer( +template <> +TORCH_API void THP_encodeBuffer>( uint8_t* dst, const c10::complex* src, THPByteOrder order, @@ -468,4 +359,16 @@ void THP_encodeComplexDoubleBuffer( } } +#define DEFINE_ENCODE(TYPE) \ + template TORCH_API void THP_encodeBuffer( \ + uint8_t * dst, const TYPE* src, THPByteOrder order, size_t len); + +DEFINE_ENCODE(int16_t) +DEFINE_ENCODE(int32_t) +DEFINE_ENCODE(int64_t) +DEFINE_ENCODE(float) +DEFINE_ENCODE(double) + +#undef DEFINE_ENCODE + } // namespace torch::utils diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h index 7fdbe4da8fe44f..d586270fbfc7c2 100644 --- a/torch/csrc/utils/byte_order.h +++ b/torch/csrc/utils/byte_order.h @@ -68,148 +68,13 @@ enum THPByteOrder { THP_LITTLE_ENDIAN = 0, THP_BIG_ENDIAN = 1 }; TORCH_API THPByteOrder THP_nativeByteOrder(); -TORCH_API void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len); -TORCH_API void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); -TORCH_API void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - bool do_byte_swap, - size_t len); - -TORCH_API void THP_decodeInt16Buffer( - int16_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeInt32Buffer( - int32_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeInt64Buffer( - int64_t* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeHalfBuffer( - c10::Half* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeFloatBuffer( - float* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeDoubleBuffer( - double* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeBFloat16Buffer( - at::BFloat16* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeFloat8_e5m2Buffer( - at::Float8_e5m2* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e4m3fnBuffer( - at::Float8_e4m3fn* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e5m2fnuzBuffer( - at::Float8_e5m2fnuz* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeFloat8_e4m3fnuzBuffer( - at::Float8_e4m3fnuz* dst, - const uint8_t* src, - size_t len); -TORCH_API void THP_decodeComplexFloatBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_decodeComplexDoubleBuffer( - c10::complex* dst, - const uint8_t* src, - THPByteOrder order, - size_t len); +template +TORCH_API void THP_decodeBuffer(T* dst, const uint8_t* src, U type, size_t len); -TORCH_API void THP_encodeInt16Buffer( - uint8_t* dst, - const int16_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeInt32Buffer( - uint8_t* dst, - const int32_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeInt64Buffer( - uint8_t* dst, - const int64_t* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeFloatBuffer( - uint8_t* dst, - const float* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeDoubleBuffer( - uint8_t* dst, - const double* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeComplexFloatBuffer( - uint8_t* dst, - const c10::complex* src, - THPByteOrder order, - size_t len); -TORCH_API void THP_encodeComplexDoubleBuffer( +template +TORCH_API void THP_encodeBuffer( uint8_t* dst, - const c10::complex* src, + const T* src, THPByteOrder order, size_t len); diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index 79c05f3c9ada77..c0147977ead29e 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -28,7 +28,8 @@ void set_requires_device_init(at::DeviceType device_type, bool value); inline void maybe_initialize_device(at::Device& device) { // Add more devices here to enable lazy initialization. - if (device.is_cuda() || device.is_xpu() || device.is_privateuseone()) { + if (device.is_cuda() || device.is_xpu() || device.is_privateuseone() || + device.is_hpu()) { device_lazy_init(device.type()); } } diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 455d1c95ac571c..aa875680788671 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -726,6 +726,7 @@ void initDispatchBindings(PyObject* module) { DEF_ONE(PreDispatch) DEF_ONE(Functionalize) DEF_ONE(AutocastCPU) + DEF_ONE(AutocastMPS) DEF_ONE(AutocastXPU) DEF_ONE(AutocastHPU) DEF_ONE(AutocastIPU) diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 10fbcfc41b2d74..de58b1965492da 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -219,37 +219,13 @@ void recursive_store( auto new_obj = py::reinterpret_borrow(obj); auto val = new_obj.cast(); const double double_val = val.guard_float(__FILE__, __LINE__); - switch (elementSize) { - case 8: - *reinterpret_cast(data) = double_val; - break; - case 4: - *reinterpret_cast(data) = static_cast(double_val); - break; - } - return; + obj = Py_BuildValue("d", double_val); } if (is_symint) { auto new_obj = py::reinterpret_borrow(obj); auto val = new_obj.cast(); - const auto int_val = val.guard_int(__FILE__, __LINE__); - switch (elementSize) { - case 8: - *reinterpret_cast(data) = int_val; - break; - case 4: - *reinterpret_cast(data) = static_cast(int_val); - break; - case 2: - *reinterpret_cast(data) = static_cast(int_val); - break; - case 1: - *reinterpret_cast(data) = static_cast(int_val); - break; - default: - TORCH_CHECK(false, "Unexpected elementSize ", elementSize); - } - return; + const int64_t int_val = val.guard_int(__FILE__, __LINE__); + obj = Py_BuildValue("i", int_val); } torch::utils::store_scalar(data, scalarType, obj); return; diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index d8e49d9d56b9a3..6e6c9a4564b65e 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -197,6 +197,72 @@ PyObject* THXPModule_emptyCache(PyObject* self, PyObject* noargs) { Py_RETURN_NONE; } +PyObject* THXPModule_memoryStats(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + + using c10::CachingDeviceAllocator::DeviceStats; + using c10::CachingDeviceAllocator::Stat; + using c10::CachingDeviceAllocator::StatArray; + using c10::CachingDeviceAllocator::StatType; + + const auto statToDict = [](const Stat& stat) { + py::dict dict; + + dict["current"] = stat.current; + dict["peak"] = stat.peak; + dict["allocated"] = stat.allocated; + dict["freed"] = stat.freed; + return dict; + }; + + const auto statArrayToDict = [=](const StatArray& statArray) { + const std::array(StatType::NUM_TYPES)> + statTypeNames = {"all", "small_pool", "large_pool"}; + py::dict dict; + for (const auto i : c10::irange(statTypeNames.size())) { + dict[statTypeNames[i]] = statToDict(statArray[i]); + } + return dict; + }; + + const DeviceStats stats = + c10::xpu::XPUCachingAllocator::getDeviceStats(device_index); + + py::dict result; + result["allocated_bytes"] = statArrayToDict(stats.allocated_bytes); + result["reserved_bytes"] = statArrayToDict(stats.reserved_bytes); + result["active_bytes"] = statArrayToDict(stats.active_bytes); + result["requested_bytes"] = statArrayToDict(stats.requested_bytes); + + return result.release().ptr(); + END_HANDLE_TH_ERRORS +} + +PyObject* THXPModule_resetPeakMemoryStats(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::xpu::XPUCachingAllocator::resetPeakStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + +PyObject* THXPModule_resetAccumulatedMemoryStats( + PyObject* self, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "invalid argument to reset_accumulated_memory_stats"); + const auto device_index = THPUtils_unpackDeviceIndex(arg); + c10::xpu::XPUCachingAllocator::resetAccumulatedStats(device_index); + END_HANDLE_TH_ERRORS + Py_RETURN_NONE; +} + // XPU module initialization static void registerXpuDeviceProperties(PyObject* module) { @@ -353,6 +419,15 @@ static struct PyMethodDef _THXPModule_methods[] = { nullptr}, {"_xpu_synchronize", THXPModule_xpuSynchronize, METH_O, nullptr}, {"_xpu_emptyCache", THXPModule_emptyCache, METH_NOARGS, nullptr}, + {"_xpu_memoryStats", THXPModule_memoryStats, METH_O, nullptr}, + {"_xpu_resetAccumulatedMemoryStats", + THXPModule_resetAccumulatedMemoryStats, + METH_O, + nullptr}, + {"_xpu_resetPeakMemoryStats", + THXPModule_resetPeakMemoryStats, + METH_O, + nullptr}, {nullptr}}; PyMethodDef* THXPModule_methods() { diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index eb30f8f057ed25..cd25b2ab3de6fd 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1122,7 +1122,8 @@ def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: handle = _get_amdsmi_handler(device) - return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"] + clk_info = amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX) + return clk_info["clk"] if "clk" in clk_info else clk_info["cur_clk"] def memory_usage(device: Optional[Union[Device, int]] = None) -> int: @@ -1628,6 +1629,7 @@ def addmm_kernel_impl(*args, **kwargs): "memory_usage", "MemPool", "MemPoolContext", + "use_mem_pool", "temperature", "power_draw", "clock_rate", diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 1726cbe439dc60..af2c8d480c8345 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -52,6 +52,7 @@ "change_current_allocator", "MemPool", "MemPoolContext", + "use_mem_pool", ] @@ -64,8 +65,20 @@ # Define dummy base classes torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool") torch._C.__dict__["_MemPoolContext"] = _dummy_type("_MemPoolContext") + torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type( + "_cuda_beginAllocateToPool" + ) + torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type( + "_cuda_endAllocateCurrentStreamToPool" + ) -from torch._C import _cuda_CUDAAllocator, _MemPool, _MemPoolContext # noqa: F401 +from torch._C import ( # noqa: F401 + _cuda_beginAllocateToPool, + _cuda_CUDAAllocator, + _cuda_endAllocateCurrentStreamToPool, + _MemPool, + _MemPoolContext, +) def _host_allocator(): @@ -1002,3 +1015,27 @@ def __init__(self, pool: MemPool): def active_pool() -> Optional[_MemPool]: r"""Returns the active MemPool""" return _MemPoolContext.active_pool() + + +@contextlib.contextmanager +def use_mem_pool(pool: MemPool, device: Union[Device, int] = None): + r"""A context manager that routes allocations to a given pool. + + Args: + pool(torch.cuda.MemPool): a MemPool object to be made active so that + allocations route to this pool. + device (torch.device or int, optional): selected device. Uses MemPool on + the current device, given by :func:`~torch.cuda.current_device`, + if :attr:`device` is ``None`` (default). + + """ + ctx = MemPoolContext(pool) + device_index = ( + torch.cuda.current_device() if device is None else _get_device_index(device) + ) + _cuda_beginAllocateToPool(device_index, pool.id) + try: + yield + finally: + _cuda_endAllocateCurrentStreamToPool(device_index, pool.id) + del ctx diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 1351d781ec91fd..4e10f4594c1529 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -4,8 +4,8 @@ import torch import torch._dynamo.compiled_autograd as ca import torch.distributed as dist -from torch.distributed._tensor import DTensor from torch.distributed.distributed_c10d import ReduceOp +from torch.distributed.tensor import DTensor from ._fsdp_common import ( _get_dim0_padded_size, diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 36b181250f28da..31b74079aaa8be 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -10,8 +10,8 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec @dataclass diff --git a/torch/distributed/_composable/fsdp/_fsdp_init.py b/torch/distributed/_composable/fsdp/_fsdp_init.py index 1a137d24b86723..c07e323449f30c 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_init.py +++ b/torch/distributed/_composable/fsdp/_fsdp_init.py @@ -4,8 +4,8 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh from torch.utils._python_dispatch import is_traceable_wrapper_subclass from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index 7bd5dd72badabd..fef1c865b6b08b 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -9,14 +9,10 @@ import torch.nn as nn from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor, Replicate, Shard -from torch.distributed._tensor.device_mesh import _mesh_resources -from torch.distributed._tensor.placement_types import ( - _StridedShard, - DTensorSpec, - Placement, - TensorMeta, -) +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import _mesh_resources +from torch.distributed.tensor.placement_types import _StridedShard, Placement from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ( @@ -69,73 +65,77 @@ lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 -lib.define("set_(Tensor(a!) tensor, Tensor data) -> ()") +lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()") -@torch.library.impl(lib, "set_", "Meta") -@torch.library.impl(lib, "set_", "CUDA") -@torch.library.impl(lib, "set_", "CPU") -def set_(tensor, data): - tensor.set_(data) +@torch.library.impl(lib, "copy_", "Meta") +@torch.library.impl(lib, "copy_", "CUDA") +@torch.library.impl(lib, "copy_", "CPU") +def copy_(tensor, data): + tensor.copy_(data) """ -[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)] +[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_] -Currently we don't functionalize `fsdp.set_` op or `inductor.resize_storage_bytes_(0)` op +Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op (i.e. they show up as a mutation op in the middle of the AOT joint graph). Reason: Traceable FSDP2 compiled autograd BWD graph have the following traits: (1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors). -(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param). +(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param). (3) They are both subclasses. The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing). So this doesn't work at all for Traceable FSDP2. -The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops. +The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops. This avoids the problem above, because from AOTAutograd point-of-view there are no mutations that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.) We can avoid this functionalization because: -(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created), -so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream. +(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created), +so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream. (2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops. So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay (since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore). -Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places? -A: Yes - this is WIP but we will try to get back to functional graph as early as possible in the lowering process. -Specifically, we believe we can move both .set_ and .resize_(0) ops to end of graph in AOT joint graph before partitioner -(i.e. effectively "re-functionalizing" those ops). Put it in another way, we avoid functionalization for those two ops just to -make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph. -This requires a custom FX pass but we believe it's not hard to write and maintain. - -Q: What's the importance of partitioner not saving views of nn.Parameter as FWD saved tensors? -A: This is critical: we do want to save FWD nn.Parameter graph input (instead of its view) for BWD use, -so that downstream ops in BWD graph uses the post-`.set_` nn.Parameter instead of any of its saved views as input. -This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original -nn.Parameter in order to see the result of .set_. +Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance? +A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops +for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0). + +TODO: +Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward, +it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are: + +Cons (of functionalizing those ops): +(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time +in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them +in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse +peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_: +https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089) + +Pros (of functionalizing): +(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations +mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more). +(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it. +But maybe there are few enough mutations induced by FSDP for this to matter. """ -@torch.library.impl(lib, "set_", "Functionalize") -def set__functionalize(tensor, data): +@torch.library.impl(lib, "copy_", "Functionalize") +def copy__functionalize(tensor, data): torch._sync(tensor) torch._sync(data) - # AOTDispatcher needs to know if any inputs had their storages mutated. - # (Why? It sometimes detaches inputs before sending them into the graph, - # when it sees that they do not need to have any gradients computed) - torch._functionalize_set_storage_changed(tensor) tensor_inner = torch._from_functional_tensor(tensor) data_inner = torch._from_functional_tensor(data) with torch._C._ExcludeDispatchKeyGuard( torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) ): - torch.ops.fsdp.set_.default(tensor_inner, data_inner) + torch.ops.fsdp.copy_.default(tensor_inner, data_inner) -torch.fx.node.has_side_effect(torch.ops.fsdp.set_.default) +torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) class ShardedState(Enum): @@ -479,7 +479,12 @@ def init_unsharded_param(self): with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( self._unsharded_param ): - torch.ops.fsdp.set_.default(self._unsharded_param, unsharded_param) + # NOTE: Under compile, if an unsharded param goes through + # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those + # resize_ and copy_ ops in a compiler graph pass + # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance. + alloc_storage(self._unsharded_param) + torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param) else: self._unsharded_param = nn.Parameter( unsharded_param, requires_grad=self.sharded_param.requires_grad @@ -609,13 +614,26 @@ def alloc_all_gather_outputs(self) -> None: alloc_storage(tensor) def free_unsharded_param(self) -> None: - for tensor in itertools.chain( - self.all_gather_outputs, self._unsharded_inner_tensors - ): - free_storage(tensor) if ca.compiled_autograd_enabled: + """ + Assumptions under compile: + - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. + Instead, we resize `self._unsharded_param` storage size to full and then + explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param` + in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove + the resize_ and copy_ ops in a compiler graph pass to recover performance.) + - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT + graph inputs. They are created within the graph and is guaranteed to be freed + by the end of the graph. They don't leak outside of the graph. + """ + self._unsharded_param.untyped_storage().resize_(0) self.all_gather_outputs = [] self._unsharded_inner_tensors = [] + else: + for tensor in itertools.chain( + self.all_gather_outputs, self._unsharded_inner_tensors + ): + free_storage(tensor) @property def all_gather_inputs(self) -> List[torch.Tensor]: # 1D @@ -677,10 +695,11 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: if isinstance(grad, AsyncCollectiveTensor): grad = grad.wait() assert isinstance(grad, DTensor), f"{type(grad)}" - if any(pl.is_partial() for pl in grad.placements): - placements = [ - Replicate() if pl.is_partial() else pl for pl in grad.placements - ] + placements = self._tp_spec.placements + if placements != grad.placements: + assert len(self._tp_spec.placements) == len( + grad.placements + ), f"{self._tp_spec=} {grad.placements=}" grad = grad.redistribute(placements=placements) grad = grad._local_tensor return grad diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index 32f65998d6ce4c..e90863479a8d1e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -12,7 +12,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle -from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_collectives import ( AllGatherResult, foreach_all_gather, @@ -73,9 +73,12 @@ def lazy_init(self): self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles def get_all_gather_streams( - self, training_state: TrainingState + self, async_op: bool, training_state: TrainingState ) -> Tuple[torch.cuda.Stream, torch.cuda.Stream]: - if training_state in (TrainingState.FORWARD, TrainingState.PRE_BACKWARD): + if not async_op and training_state in ( + TrainingState.FORWARD, + TrainingState.PRE_BACKWARD, + ): # Use separate streams for implicit prefetching return self.all_gather_copy_in_stream, self.all_gather_stream current_stream = torch.cuda.current_stream() @@ -127,6 +130,7 @@ def __init__( self.post_forward_mesh_info = post_forward_mesh_info self.device = device self.mp_policy = mp_policy + self.offload_policy = offload_policy self._training_state = TrainingState.IDLE # Group's sharded state always matches its parameters' sharded states self._sharded_state = ShardedState.SHARDED @@ -155,6 +159,10 @@ def __init__( # Optional custom reduce-scatter reduce op (e.g. to divide by a # factor other than the shard world size) self.reduce_scatter_reduce_op: Optional[dist.ReduceOp] = None + # `async_op` arg used for pre-forward/pre-backward unshard; can be + # overridden to only do explicit prefetching and avoid inter-stream + # fragmentation from using separate unshard streams + self.unshard_async_op: bool = False # - CUDA events for stream synchronization # Holds the all-gather output buffer, sync objects, and metadata @@ -201,18 +209,8 @@ def lazy_init(self): for fsdp_param in self.fsdp_params: fsdp_param.reset_sharded_param() self._reset_sharded_params = True - param_names_on_meta = [ - fsdp_param._param_fqn - for fsdp_param in self.fsdp_params - if fsdp_param.sharded_param.device.type == "meta" - ] - if param_names_on_meta: - raise RuntimeError( - "FSDP parameters should be materialized from meta device before training, " - f"but the following were still on meta device: {param_names_on_meta}\n" - "For example, call module.to_empty(device) to materialize to device and " - "call module.reset_parameters() on each module to initialize values." - ) + self._validate_no_meta_params() + self._validate_cpu_offload_params() # Initialize mixed precision attributes lazily in case the user changes # the parameter dtypes after construction time but before forward self._init_mp_dtypes() @@ -234,7 +232,7 @@ def unshard(self, async_op: bool = False): self.fsdp_params, self._all_gather_process_group, async_op, - *self.comm_ctx.get_all_gather_streams(self._training_state), + *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), self.device, ) @@ -249,6 +247,7 @@ def wait_for_unshard(self): """ if not self._all_gather_result: return # no preceding unshard + async_op = self._all_gather_result.all_gather_work is not None if self._training_state == TrainingState.FORWARD: # implicit prefetch if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) @@ -264,7 +263,9 @@ def wait_for_unshard(self): self._to_unsharded() all_gather_copy_out_event = torch.cuda.Event() all_gather_copy_out_event.record() - if self._training_state == TrainingState.FORWARD: + if not async_op and self._training_state == TrainingState.FORWARD: + # Defer free to allow for overlap of this copy-out with next + # all-gather collective self.comm_ctx.all_gather_state = AllGatherState( self._all_gather_result, all_gather_copy_out_event ) @@ -297,7 +298,7 @@ def pre_forward( logger.debug("%s", self._with_fqn("FSDP::pre_forward")) with record_function(self._with_fqn("FSDP::pre_forward")): self._training_state = TrainingState.FORWARD - self.unshard() + self.unshard(self.unshard_async_op) self.wait_for_unshard() args, kwargs = self._register_post_backward_hook(args, kwargs) return args, kwargs @@ -325,9 +326,9 @@ def pre_backward(self, default_prefetch: bool, *unused: Any): logger.debug("%s", self._with_fqn("FSDP::pre_backward")) with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD - self.unshard() # no-op if prefetched + self.unshard(self.unshard_async_op) # no-op if prefetched self.wait_for_unshard() - if default_prefetch: + if default_prefetch and not ca.compiled_autograd_enabled: self._backward_prefetch() def post_backward(self, *unused: Any): @@ -430,7 +431,8 @@ def _prefetch_unshard( with record_function( f"FSDP::{pass_type}_prefetch for {target_fqn}" ), target_fsdp_param_group.use_training_state(training_state): - target_fsdp_param_group.unshard() + async_op = target_fsdp_param_group.unshard_async_op + target_fsdp_param_group.unshard(async_op) # Utilities # def _to_sharded(self): @@ -476,9 +478,9 @@ def use_training_state(self, training_state: TrainingState): def _register_post_backward_hook( self, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - # Compile relies on `root_post_backward_callback` to call each + # Traceable FSDP2 relies on `root_post_backward_callback` to call each # `FSDPParamGroup.post_backward` - if ca.compiled_autograd_enabled: + if (not torch._dynamo.config.skip_fsdp_hooks) or ca.compiled_autograd_enabled: return args, kwargs if not torch.is_grad_enabled(): return args, kwargs @@ -569,6 +571,36 @@ def _with_fqn(self, label: str) -> str: def __repr__(self): return f"FSDPParamGroup(fqn={self._module_fqn})" + def _validate_no_meta_params(self): + param_names_on_meta = [ + fsdp_param._param_fqn + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type == "meta" + ] + if param_names_on_meta: + raise RuntimeError( + "FSDP parameters should be materialized from meta device before training, " + f"but the following were still on meta device: {param_names_on_meta}\n" + "For example, call module.to_empty(device) to materialize to device and " + "call module.reset_parameters() on each module to initialize values." + ) + + def _validate_cpu_offload_params(self): + if not isinstance(self.offload_policy, CPUOffloadPolicy): + return + fsdp_params_not_on_cpu = [ + fsdp_param + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type != "cpu" + ] + if fsdp_params_not_on_cpu: + raise RuntimeError( + "FSDP parameters should be materialized on CPU when enabling CPU offloading. " + 'For example, load a CPU state dict or call module.to_empty(device="cpu"). ' + "Found following parameters on non-CPU device: " + f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n" + ) + def _get_param_module_infos( params: List[nn.Parameter], modules: Tuple[nn.Module, ...] @@ -602,14 +634,27 @@ def _get_param_module_infos( class RegisterPostBackwardFunction(torch.autograd.Function): + @staticmethod + def _assert_not_tracing_fsdp(): + if ca.compiled_autograd_enabled: + # TODO: Find a way to print the offending FSDP2 module. + msg = """\ +When Traceable FSDP2 is enabled, we rely on `root_post_backward_callback` to call +each `FSDPParamGroup.post_backward`, and we should not be calling into `RegisterPostBackwardFunction`. +If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also +compile the forward part if you want to use Traceable FSDP2.""" + torch._dynamo.comptime.comptime.print(msg) + raise RuntimeError(msg) + @staticmethod def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): # All tensors in `inputs` should require gradient + RegisterPostBackwardFunction._assert_not_tracing_fsdp() ctx.param_group = param_group - ctx.set_materialize_grads(False) return inputs @staticmethod def backward(ctx, *grads: torch.Tensor): + RegisterPostBackwardFunction._assert_not_tracing_fsdp() ctx.param_group.post_backward() return (None,) + grads diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index b06e1a6b777a1a..49c3da8fbfd02a 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from torch.distributed._composable import contract -from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor import DeviceMesh from torch.distributed.utils import _get_root_modules from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy @@ -358,6 +358,25 @@ def set_reduce_scatter_divide_factor(self, factor: float) -> None: reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor) fsdp_param_group.reduce_scatter_reduce_op = reduce_op + def _set_unshard_async_op(self, async_op: bool): + """ + Sets whether to use ``async_op=True`` or ``False`` for the pre-forward + and pre-backward unshard op. This defaults to ``False`` but can be set + to ``True`` with this method. + + Setting this to ``True`` allows the all-gather allocations to happen in + the default stream, avoiding inter-stream memory fragmentation. + However, you must use explicit prefetching (e.g. via :meth:`unshard`) + in forward to still get overlap, and the pre-all-gather ops like dtype + casting and copy-in will not overlap with compute. + """ + self_module = cast(nn.Module, self) + for module in self_module.modules(): + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.unshard_async_op = async_op + def _get_fsdp_state(self) -> FSDPState: if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: raise AssertionError(f"No FSDP state found on {self}") diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 2df9435c369a7a..4127885dccc1f6 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -97,7 +97,7 @@ def is_torchdynamo_compiling(): List[List[int]], dist.ProcessGroup, DeviceMesh, - Tuple["dist._tensor.DeviceMesh", int], + Tuple["dist.tensor.DeviceMesh", int], str, ] @@ -600,7 +600,7 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"AsyncCollectiveTensor({self.trigger_wait()})" def trigger_wait(self): @@ -653,7 +653,7 @@ def wrap(e: torch.Tensor): return out - def numpy(self): + def numpy(self): # type: ignore[override] return self.wait().numpy() diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 68df582cd5145e..d50160ca8ecc31 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -1200,7 +1200,7 @@ def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: def __hash__(self): return id(self) - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"ShardedTensor({self._metadata})" @dataclass diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 340cb43e295e8d..be71f88cd52b8e 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -27,7 +27,8 @@ if dist.is_available() or TYPE_CHECKING: from torch.distributed import distributed_c10d from torch.distributed._shard.sharded_tensor import ShardedTensor - from torch.distributed._tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _identity_func( @@ -551,8 +552,17 @@ def _distribute_tensors( local_state = _local_state[0] full_tensor = _local_state[1] - local_state_dict[key] = distribute_tensor( - full_tensor, local_state.device_mesh, local_state.placements + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + local_tensor = full_tensor[slices] + local_state_dict[key] = DTensor.from_local( + local_tensor, local_state.device_mesh, local_state.placements ) @@ -627,16 +637,17 @@ def _distribute_state_dict( local_state_dict[key] = value.cpu() else: assert isinstance(value, torch.Tensor) - full_tensor = value.detach().to(device) local_state = local_state_dict.get(key, None) if local_state is None: continue elif isinstance(local_state, DTensor): - local_state_dict[key] = (local_state, full_tensor) + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) else: - local_state_dict[key] = full_tensor - - _distribute_tensors(local_state_dict, [key], device, pg) + local_state_dict[key] = value.detach().to(device) # These APIs are from torch.distributed.checkpoint. diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index c741ea5a0c46e9..40f9727015d764 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -1,58 +1,44 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates +""" +NOTICE: DTensor has moved to torch.distributed.tensor -import torch -import torch.distributed._tensor.ops as _ops # force import all built-in dtensor ops -from torch.distributed._tensor.api import ( +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"torch.distributed.tensor.{submodule}" + sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from torch.distributed.tensor import ( # noqa: F401 + DeviceMesh, distribute_module, distribute_tensor, DTensor, empty, full, + init_device_mesh, ones, - rand, - randn, - zeros, -) -from torch.distributed._tensor.placement_types import ( Partial, Placement, + rand, + randn, Replicate, Shard, + zeros, ) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.optim.optimizer import ( - _foreach_supported_types as _optim_foreach_supported_types, -) -from torch.utils._foreach_utils import ( - _foreach_supported_types as _util_foreach_supported_types, -) - - -# All public APIs from dtensor package -__all__ = [ - "DTensor", - "DeviceMesh", - "distribute_tensor", - "distribute_module", - "init_device_mesh,", - "Shard", - "Replicate", - "Partial", - "Placement", - "ones", - "empty", - "full", - "rand", - "randn", - "zeros", -] - - -# Append DTensor to the list of supported types for foreach implementation for optimizer -# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. -if DTensor not in _optim_foreach_supported_types: - _optim_foreach_supported_types.append(DTensor) - -if DTensor not in _util_foreach_supported_types: - _util_foreach_supported_types.append(DTensor) diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 7970cd7fb4bd51..9e5742156a86ca 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -1,1233 +1,9 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -import inspect -import warnings -from typing import Any, Callable, cast, Optional, Sequence, Tuple +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases -import torch -import torch.distributed._tensor._dispatch as op_dispatch -import torch.distributed._tensor.random as random -import torch.nn as nn -from torch.distributed._tensor._collective_utils import ( - check_tensor_meta, - mesh_broadcast, -) -from torch.distributed._tensor._redistribute import ( - Redistribute, - redistribute_local_tensor, -) -from torch.distributed._tensor._utils import ( - compute_global_tensor_info, - compute_local_shape, - normalize_to_torch_size, -) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Partial, - Placement, - Replicate, - Shard, - TensorMeta, -) -from torch.distributed._tensor.random import ( - is_rng_supported_mesh, - OffsetBasedRNGTracker, -) -from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +TODO: throw warnings when this module imported +""" - -__all__ = [ - "DTensor", - "distribute_tensor", - "distribute_module", - "ones", - "empty", - "full", - "rand", - "randn", - "zeros", -] - -aten = torch.ops.aten - - -# NOTE [Autograd interaction between torch.Tensor] -# -# The autograd functions defined below are being used by the public -# facing APIs (i.e. from_local, to_local) to ensure DTensor to work -# together with torch.Tensor within the autograd engine. This -# allows DTensor to only exist on part of the module hierarchy. -# -# As an example, we have the a module that consists of submodules -# A, B, and C, the execution flow would be like: -# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) -# -# Suppose I only want to make Module B be a sharded module with -# DTensor params, the following forward/backward should work: -# -# input(torch.Tensor) -> Module A -# -> DTensor input (from_local) -> Sharded Module B -> DTensor output -# -> torch.Tensor output (to_local) -> Module C -# -# So from_local/to_local must be Autograd functions. -# -class _ToTorchTensor(torch.autograd.Function): - @staticmethod - def forward( # type: ignore[override] - ctx, - input: "DTensor", - grad_placements: Optional[Sequence[Placement]], - ): - ctx.dtensor_spec = input._spec - ctx.grad_placements = grad_placements - local_tensor = input._local_tensor - - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this DTensor. - return local_tensor.view_as(local_tensor) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] - dtensor_spec = ctx.dtensor_spec - mesh = dtensor_spec.mesh - grad_placements = ctx.grad_placements - dtensor_meta = dtensor_spec.tensor_meta - - _, tensor_stride = compute_global_tensor_info( - grad_output, mesh, dtensor_spec.placements - ) - tensor_stride = tuple(tensor_stride) - grad_placements = grad_placements or dtensor_spec.placements - grad_spec = DTensorSpec( - mesh, - grad_placements, - tensor_meta=TensorMeta( - shape=dtensor_meta.shape, - stride=tensor_stride, - dtype=dtensor_meta.dtype, - ), - ) - - return ( - DTensor( - grad_output, - grad_spec, - requires_grad=grad_output.requires_grad, - ), - None, - ) - - -class _FromTorchTensor(torch.autograd.Function): - @staticmethod - def forward( # type: ignore[override] - ctx, # pyre-ignore[2]: Parameter must be annotated. - input: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], - run_check: bool, - shape: Optional[torch.Size] = None, - stride: Optional[Tuple[int, ...]] = None, - ) -> "DTensor": - ctx.previous_placement = placements - ctx.previous_device_mesh = device_mesh - - if shape and stride: - tensor_shape, tensor_stride = shape, stride - elif not shape and not stride: - # if it's not by default run_check, we assume user is certain that each - # rank has the same tensor shape, and we just use that to calculate the - # global shape - global_shape, global_stride = compute_global_tensor_info( - input, device_mesh, placements - ) - tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) - else: - raise RuntimeError( - f"Found shape:{shape}, stride:{stride}.", - "Please pass both shape and stride at the same time.", - ) - - if device_mesh.get_coordinate() is None: - # if the global rank is not participating in the device mesh, we - # simply set the local tensor to an empty tensor - input = input.new_empty(0, requires_grad=input.requires_grad) - elif run_check: - # TODO: support uneven sharding when global shape/stride not passed, by - # building the global TensorMeta during check_tensor_meta - check_shape_stride = not shape and not stride - check_tensor_meta(input, check_shape_stride=check_shape_stride) - # TODO: See if we need to make this run_check logic - # have a corresponding backward. - for idx, placement in enumerate(placements): - if placement.is_replicate(): - # broadcast rank 0 tensor to all ranks - # only broadcast if run_check is True - input = input.contiguous() - mesh_broadcast(input, device_mesh, mesh_dim=idx) - - dist_spec = DTensorSpec( - device_mesh, - placements, - tensor_meta=TensorMeta( - tensor_shape, - tensor_stride, - input.dtype, - ), - ) - - # We want a fresh Tensor object that shares memory with the input tensor - dist_tensor = DTensor( - input.view_as(input), - dist_spec, - # requires_grad of the dist tensor depends on if input - # requires_grad or not - requires_grad=input.requires_grad, - ) - return dist_tensor - - @staticmethod - def backward(ctx, grad_output: "DTensor"): # type: ignore[override] - previous_placement = ctx.previous_placement - previous_device_mesh = ctx.previous_device_mesh - - # reshard to the placement when creating DistributedTensor - # so that the gradient layout matches, and we could return - # local gradients directly - if grad_output.placements != previous_placement: - current_spec = grad_output._spec - target_spec = DTensorSpec( - previous_device_mesh, - previous_placement, - tensor_meta=grad_output._spec.tensor_meta, - ) - local_tensor = grad_output._local_tensor - output = redistribute_local_tensor( - local_tensor, current_spec, target_spec, is_backward=True - ) - # TODO: return the redistributed local tensor directly without - # differentiable backward. see if this make sense for all cases. - return output, None, None, None, None, None - - # TODO: backward is also differentiable now, add a test - # to test higher level gradients. - return grad_output.to_local(), None, None, None, None, None - - -class DTensor(torch.Tensor): - """ - ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like - abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding - layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: - - * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension - * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension - * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension - - When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue - communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the - placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. - - To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` - requires every Tensor argument of the operator be DTensor. - - """ - - _local_tensor: torch.Tensor - _spec: DTensorSpec - __slots__ = ["_local_tensor", "_spec"] - - # _op_dispatcher instance as a class attribute to handle runtime dispatching logic - _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() - - @staticmethod - @torch._disable_dynamo - def __new__( - cls, - local_tensor: torch.Tensor, - spec: DTensorSpec, - *, - requires_grad: bool, - ) -> "DTensor": - """ - Construct a DTensor from a local tensor, device mesh, and placement and - other tensor properties (i.e. shape, requires_grad, strides, etc). - Note: This is not a public API and it's only supposed to be used by the - operator implementations and internals. If you want to construct a - DTensor from a local tensor, consider using ``DTensor.from_local``, if - you want to construct a DTensor from a "global" tensor (where you - already have tensor initialized and want to shard this tensor), - consider using ``distribute_tensor``. - """ - if local_tensor.requires_grad and not requires_grad: - warnings.warn( - "To construct DTensor from torch.Tensor, it's recommended to " - "use local_tensor.detach() and make requires_grad consistent." - ) - - # new method instruct wrapper tensor from local_tensor and add - # placement spec, it does not do actual distribution - assert spec.tensor_meta is not None, "TensorMeta should not be None!" - r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, - spec.tensor_meta.shape, - strides=spec.tensor_meta.stride, - dtype=local_tensor.dtype, - device=local_tensor.device, - layout=local_tensor.layout, - requires_grad=requires_grad, - ) - - r._spec = spec - r._local_tensor = local_tensor - return r - - # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. - # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): - # TODO: consider all_gather the local tensors for better debugging - return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" - - def __tensor_flatten__(self): - """ - protocol to inform how to flatten a DTensor to local tensor - for PT2 tracing - """ - return ["_local_tensor"], (self._spec, self.requires_grad) - - @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - assert ( - flatten_spec is not None - ), "Expecting spec to be not None from `__tensor_flatten__` return value!" - local_tensor = inner_tensors["_local_tensor"] - spec, requires_grad = flatten_spec - unflatten_tensor_meta = TensorMeta( - shape=outer_size, - stride=outer_stride, - dtype=spec.tensor_meta.dtype, - ) - unflatten_spec = DTensorSpec( - spec.mesh, - spec.placements, - tensor_meta=unflatten_tensor_meta, - ) - return DTensor( - local_tensor, - unflatten_spec, - requires_grad=requires_grad, - ) - - def __coerce_tangent_metadata__(self): - if not any(isinstance(p, Partial) for p in self.placements): - return self - placements = [ - Replicate() if isinstance(p, Partial) else p for p in self.placements - ] - return self.redistribute(device_mesh=self.device_mesh, placements=placements) - - def __coerce_same_metadata_as_tangent__(self, flatten_spec): - (spec, _) = flatten_spec # Result of tensor_flatten() - return self.redistribute( - device_mesh=self.device_mesh, - placements=spec.placements, - ) - - @classmethod - @torch._disable_dynamo - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - return DTensor._op_dispatcher.dispatch( - func, - args, - kwargs or {}, - ) - - @staticmethod - def from_local( - local_tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - *, - run_check: bool = False, - shape: Optional[torch.Size] = None, - stride: Optional[Tuple[int, ...]] = None, - ) -> "DTensor": - """ - Create a :class:`DTensor` from a local torch.Tensor on each rank - according to the ``device_mesh`` and ``placements`` specified. - - Args: - local_tensor (torch.Tensor): local torch.Tensor on each rank. - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the - tensor, if not specified, must be called under a DeviceMesh - context manager, default: None - placements (List[:class:`Placement`], optional): the placements that - describes how to place the local torch.Tensor on DeviceMesh, must - have the same number of elements as ``device_mesh.ndim``. - - Keyword args: - run_check (bool, optional): at a cost of extra communications, perform - sanity check across ranks to check each local tensor's meta information - to ensure correctness. If have :class:`Replicate` in ``placements``, the - data on first rank of the device mesh dimension will be broadcasted - to other ranks. default: False - shape (torch.Size, optional): A List of int which specifies the size of - DTensor which build on top of `local_tensor`. Note this needs to be - provided if the shape of ``local_tensor`` are different across the ranks. - If not provided, ``shape`` will be computed assuming the given distributed - tensor is evenly sharded across ranks. default: None - stride (tuple, optional): A List of int which specifies the stride of DTensor. - If not provided, ``stride`` will be computed assuming the given distributed - tensor is evenly sharded across ranks. default: None - - Returns: - A :class:`DTensor` object - - .. note:: When ``run_check=False``, it is the user's responsibility to ensure the - local tensor passed in is correct across ranks (i.e. the tensor is sharded for - the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). - If not, the behavior of the created DTensor is undefined. - - .. note:: ``from_local`` is differentiable, the `requires_grad` of the created - `DTensor` object will depend on if `local_tensor` requires_grad or not. - """ - # if same shape/dtype, no need to run_check, if not, must allgather - # the metadatas to check the size/dtype across ranks - # There should be no data communication unless there's replication - # strategy, where we broadcast the replication from the first rank - # in the mesh dimension - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - - # convert the local tensor to desired device base on device mesh's device_type - if device_type != local_tensor.device.type and not local_tensor.is_meta: - local_tensor = local_tensor.to(device_type) - - # set default placements to replicated if not specified - if placements is None: - placements = [Replicate() for _ in range(device_mesh.ndim)] - else: - placements = list(placements) - for idx, placement in enumerate(placements): - # normalize shard dim to be positive - if placement.is_shard(): - placement = cast(Shard, placement) - if placement.dim < 0: - placements[idx] = Shard(placement.dim + local_tensor.ndim) - - # `from_local` is differentiable, and the gradient of the dist tensor this function - # created should flow back the gradients to the local_tensor, so we call an autograd - # function to construct the dist tensor instead. - return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func - local_tensor, - device_mesh, - tuple(placements), - run_check, - shape, - stride, - ) - - def to_local( - self, *, grad_placements: Optional[Sequence[Placement]] = None - ) -> torch.Tensor: - """ - Get the local tensor of this DTensor on its current rank. For sharding it returns - a local shard of the logical tensor view, for replication it returns the replica on - its current rank. - - Keyword args: - grad_placements (List[:class:`Placement`], optional): the placements describes - the future layout of any gradient layout of the Tensor returned from this - function. - `to_local` converts DTensor to local tensor and the returned local tensor - might not be used as the original DTensor layout later in the code. This - argument is the hint that user can give to autograd in case the gradient - layout of the returned tensor does not match the original DTensor layout. - If not specified, we will assume the gradient layout remains the same - as the original DTensor and use that for gradient computation. - - Returns: - A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the - local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, - it means the local tensor is not ready yet (i.e. communication is not finished). In this - case, user needs to call ``wait`` to wait the local tensor to be ready. - - .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned - will depend on if the `DTensor` requires_grad or not. - """ - if not torch.is_grad_enabled(): - return self._local_tensor - - if grad_placements is not None and not isinstance(grad_placements, tuple): - grad_placements = tuple(grad_placements) - return _ToTorchTensor.apply( - self, grad_placements - ) # pyre-ignore[16]: autograd func - - def redistribute( - self, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - *, - async_op: bool = False, - ) -> "DTensor": - """ - ``redistribute`` performs necessary collective operations that redistribute the current - DTensor from its current placements to a new placements, or from is current DeviceMesh - to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by - specifying a Replicate placement for each dimension of the DeviceMesh. - - When redistributing from current to the new placements on one device mesh dimension, we - will perform the following operations including communication collective or local operation: - - 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` - 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` - 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) - 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` - 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` - - - ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors - that are created either on 1-D or N-D DeviceMesh. - - Args: - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the - DTensor. If not specified, it would use the current DTensor's DeviceMesh. - default: None - placements (List[:class:`Placement`], optional): the new placements that - describes how to place the DTensor into the DeviceMesh, must - have the same number of elements as ``device_mesh.ndim``. - default: replicate on all mesh dimensions - - Keyword args: - async_op (bool, optional): whether to perform the DTensor redistribute operation - asynchronously or not. Default: False - - Returns: - A :class:`DTensor` object - - .. note:: ``redistribute`` is differentiable, which means user do not need to worry about - the backward formula of the redistribute operation. - - .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, - Please file an issue if you need to redistribute DTensor to different DeviceMesh. - """ - # NOTE: This redistribute API currently only supports out - # of place redistribution, i.e. it always create a new - # DTensor object and leave the original one unchanged. - - # if device_mesh is not specified, use the current device_mesh - device_mesh = device_mesh or self.device_mesh - # raise error if new placements not specified - if placements is None: - raise RuntimeError("placements is needed for redistribute!") - - placements = list(placements) - for i, placement in enumerate(placements): - if placement.is_partial(): - raise RuntimeError( - "Can not redistribute to Partial, redistributing to Partial is for internal use only!" - ) - elif isinstance(placement, Shard) and placement.dim < 0: - # normalize shard dim to be positive - placements[i] = Shard(placement.dim + self.ndim) - placements = tuple(placements) - - # pyre-fixme[16]: `Redistribute` has no attribute `apply`. - return Redistribute.apply(self, device_mesh, placements, async_op) - - def full_tensor( - self, *, grad_placements: Optional[Sequence[Placement]] = None - ) -> torch.Tensor: - """ - Return the full tensor of this DTensor. It will perform necessary collectives - to gather the local tensors from other ranks in its DeviceMesh and concatenate - them together. It's a syntatic sugar of the following code: - - ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` - - Keyword args: - grad_placements (List[:class:`Placement`], optional): the placements describes - the future layout of any gradient layout of the full Tensor returned from this - function. - `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor - might not be used as the original replicated DTensor layout later in the code. This - argument is the hint that user can give to autograd in case the gradient - layout of the returned tensor does not match the original replicated DTensor layout. - If not specified, we will assume the gradient layout of the full tensor be replicated. - - Returns: - A :class:`torch.Tensor` object that represents the full tensor of this DTensor. - - .. note:: ``full_tensor`` is differentiable. - """ - - redist_res = self.redistribute( - placements=[Replicate()] * self.device_mesh.ndim, async_op=False - ) - return _ToTorchTensor.apply(redist_res, grad_placements) - - @property - def device_mesh(self) -> DeviceMesh: - """ - The :class:`DeviceMesh` attribute that associates with this DTensor object. - - .. note:: ``device_mesh`` is a read-only property, it can not be set. - """ - return self._spec.mesh - - @property - def placements(self) -> Tuple[Placement, ...]: - """ - The placements attribute of this DTensor that describes the layout of this - DTensor on the its DeviceMesh. - - .. note:: ``placements`` is a read-only property, it can not be set. - """ - return self._spec.placements - - def __create_write_items__(self, fqn: str, object: Any): - from torch.distributed.checkpoint.planner_helpers import ( - _create_write_items_for_dtensor, - ) - - if hasattr(self._local_tensor, "__create_write_items__"): - return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return [_create_write_items_for_dtensor(fqn, object)] - else: - raise RuntimeError("Unsupported tensor type!") - - def __create_chunk_list__(self): - from torch.distributed.checkpoint.planner_helpers import ( - _create_chunk_from_dtensor, - ) - - if hasattr(self._local_tensor, "__create_chunk_list__"): - return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return [_create_chunk_from_dtensor(self)] - else: - raise RuntimeError("Unsupported tensor type!") - - def __get_tensor_shard__(self, index): - if hasattr(self._local_tensor, "__get_tensor_shard__"): - return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] - elif isinstance(self._local_tensor, torch.Tensor): - return self.to_local() - else: - raise RuntimeError("Unsupported tensor type!") - - -def distribute_tensor( - tensor: torch.Tensor, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according - to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the - same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use - the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to perserve - the single-device semantic. If you want to construct a DTensor in the middle of the Autograd - computation, please use ``DTensor.from_local`` instead. - - Args: - tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you - want to shard a tensor on a dimension that is not evenly divisible by - the number of devices in that mesh dimension, we use ``torch.chunk`` - semantic to shard the tensor and scatter the shards. The uneven sharding - behavior is experimental and subject to change. - device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the - tensor, if not specified, must be called under a DeviceMesh context - manager, default: None - placements (List[:class:`Placement`], optional): the placements that - describes how to place the tensor on DeviceMesh, must have the same - number of elements as ``device_mesh.ndim``. If not specified, we will - by default replicate the tensor across the ``device_mesh`` from the - first rank of each dimension of the `device_mesh`. - - Returns: - A :class:`DTensor` or ``XLAShardedTensor`` object. - - .. note:: - When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` - return `XLAShardedTensor` instead. see [link](https://github.com/pytorch/pytorch/issues/92909) - for more details. The XLA integration is experimental and subject to change. - """ - - torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") - - # get default device mesh if there's nothing specified - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - if device_type == "xla": - try: - # call PyTorch/XLA SPMD for `xla` backend type device mesh. - # This returns XLAShardedTensor - from torch_xla.distributed.spmd import ( # type:ignore[import] - xla_distribute_tensor, - ) - - return xla_distribute_tensor( - tensor, device_mesh, placements - ) # type:ignore[return-value] - except ImportError as e: - msg = "To use DTensor API with xla, you must install the torch_xla package!" - raise ImportError(msg) from e - - # instantiate a RNG tracker if haven't. By default DTensor uses an - # OffsetBasedRNGTracker to perform random operators. - # TODO: the value assignment to global variable is not the ideal solution - # we can replace it in future. - if not random._rng_tracker and is_rng_supported_mesh(device_mesh): - random._rng_tracker = OffsetBasedRNGTracker(device_type) - - if not tensor.is_leaf: - raise RuntimeError( - "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" - ) - - # convert tensor to the corresponding device type if it's not in that device type - if device_type != tensor.device.type and not tensor.is_meta: - tensor = tensor.to(device_type) - - # set default placements to replicated if not specified - if placements is None: - placements = [Replicate() for _ in range(device_mesh.ndim)] - - if len(placements) != device_mesh.ndim: - raise ValueError( - f"`placements` must have the same length as `device_mesh.ndim`! " - f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." - ) - if isinstance(tensor, DTensor): - # if the tensor is already a DTensor, we need to check: - # 1. if the we can further shard this DTensor if the two device mesh belong to - # the same parenet mesh and further sharding is possible. - # 2. check if device mesh and placements are the same - if tensor.device_mesh != device_mesh: - raise ValueError( - f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " - f"to a different device mesh {device_mesh}." - ) - if tensor.placements != tuple(placements): - raise ValueError( - f"Cannot distribute a DTensor with placements {tensor.placements} " - f"to a different placements {placements}. do you want to call " - f"`redistribute` instead?" - ) - return tensor - - local_tensor = tensor.detach() - - # TODO(xilun): address sharding order - # distribute the tensor according to the placements. - placements = list(placements) - for idx, placement in enumerate(placements): - if placement.is_shard(): - placement = cast(Shard, placement) - if placement.dim < 0: - # normalize shard placement dim - placement = Shard(placement.dim + tensor.ndim) - placements[idx] = placement - local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) - elif placement.is_replicate(): - placement = cast(Replicate, placement) - local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx) - else: - raise RuntimeError( - f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" - ) - placements = tuple(placements) - - assert local_tensor is not None, "distributing a tensor should not be None" - # detach the local tensor passed to DTensor since after the construction - # of DTensor, autograd would work on top of DTensor instead of local tensor - spec = DTensorSpec( - mesh=device_mesh, - placements=placements, - tensor_meta=TensorMeta( - shape=tensor.size(), - stride=tensor.stride(), - dtype=tensor.dtype, - ), - ) - return DTensor( - local_tensor.requires_grad_(tensor.requires_grad), - spec, - requires_grad=tensor.requires_grad, - ) - - -def distribute_module( - module: nn.Module, - device_mesh: Optional[DeviceMesh] = None, - partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, - input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, - output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, -) -> nn.Module: - """ - This function expose three functions to control the parameters/inputs/outputs of the module: - - 1. To perform sharding on the module before runtime execution by specifying the - ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` - parameters according to the `partition_fn` specified). - 2. To control the inputs or outputs of the module during runtime execution by - specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to - :class:`DTensor`, convert the output back to ``torch.Tensor``) - - Args: - module (:class:`nn.Module`): user module to be partitioned. - device_mesh (:class:`DeviceMesh`): the device mesh to place the module. - partition_fn (Callable): the function to partition parameters (i.e. shard certain - parameters across the ``device_mesh``). If ``partition_fn`` is not specified, - by default we replicate all module parameters of ``module`` across the mesh. - input_fn (Callable): specify the input distribution, i.e. could control how the - input of the module is sharded. ``input_fn`` will be installed as a module - ``forward_pre_hook`` (pre forward hook). - output_fn (Callable): specify the output distribution, i.e. could control how the - output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be - installed as a module ``forward_hook`` (post forward hook). - - Returns: - A module that contains parameters/buffers that are all ``DTensor`` s. - - .. note:: - When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` - return nn.Module with PyTorch/XLA SPMD annotated parameters. See [link](https://github.com/pytorch/pytorch/issues/92909) - for more details. The XLA integration is experimental and subject to change. - - """ - - torch._C._log_api_usage_once("torch.dtensor.distribute_module") - - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - if device_type == "xla": - try: - # This function annotates all module parameters for auto-partitioning with - # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters - # according to the `partition_fn` specified. - from torch_xla.distributed.spmd import ( # type:ignore[import] - xla_distribute_module, - ) - - return xla_distribute_module( - module, device_mesh, partition_fn, input_fn, output_fn - ) # type:ignore[return-value] - except ImportError as e: - msg = "To use DTensor API with xla, you must install the torch_xla package!" - raise ImportError(msg) from e - - def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: - # This function loop over the immediate module parameters and - # buffers, replicate all non DTensor params/buffers to DTensor - # parameters/buffers, if they have not been partitioned in the - # partition_fn, we can't easily use `module._apply` here - # because we don't know what happened inside partition_fn as - # user could do anything, i.e. install hooks, and we want to - # preserve those. - full_replicate = [Replicate()] * mesh.ndim - for key, param in m._parameters.items(): - if param is not None and not isinstance(param, DTensor): - m.register_parameter( - key, - nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), - ) - for key, buffer in m._buffers.items(): - if buffer is not None and not isinstance(buffer, DTensor): - m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) - - if partition_fn is None: - # if partition_fn not specified, we by default replicate - # all module params/buffers - for name, submod in module.named_modules(): - replicate_module_params_buffers(submod, device_mesh) - else: - # apply partition_fun to submodules - for name, submod in module.named_modules(): - partition_fn(name, submod, device_mesh) - replicate_module_params_buffers(submod, device_mesh) - - # register input_fn as module forward pre hook - if input_fn is not None: - # check the input_fn signature - num_args = len(inspect.signature(input_fn).parameters) - if num_args == 2: - # input_fn only takes in inputs and device mesh - warnings.warn( - "Deprecating input_fn that takes two arguments (inputs, device_mesh), " - "please use input_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] - elif num_args == 3: - # input_fn takes in module, inputs, device mesh - module.register_forward_pre_hook( - lambda mod, inputs: input_fn(mod, inputs, device_mesh) - ) - else: - raise ValueError( - f"input_fn should take in 3 arguments, but got {num_args} arguments!" - ) - # register output_fn as module forward hook - if output_fn is not None: - num_args = len(inspect.signature(output_fn).parameters) - if num_args == 2: - # output_fn only takes in outputs and device mesh - warnings.warn( - "Deprecating output_fn that takes two arguments (inputs, device_mesh), " - "please use output_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - module.register_forward_hook( - lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] - ) - elif num_args == 3: - module.register_forward_hook( - lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) - ) - else: - raise ValueError( - f"output_fn should take in 3 arguments, but got {num_args} arguments!" - ) - - return module - - -# Below are tensor factory function APIs, which are used to create a DTensor directly. We need -# to make separate factory function APIs because tensor subclass could not override the tensor -# factory methods, and we need user to call the factory functions with user intended device_mesh -# and placements to create a proper DTensor. - - -def _dtensor_init_helper( # type: ignore[no-untyped-def] - init_op, - size: torch.Size, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, - **kwargs, -) -> DTensor: - # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta - - # if device_mesh is None, use the one from mesh resources - device_mesh = device_mesh or _mesh_resources.get_current_mesh() - kwargs["device"] = device_mesh.device_type - - # set default placements to replicated if not specified - placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) - - # check device_mesh againts placements - assert device_mesh.ndim == len( - placements - ), "mesh dimension does not match the length of placements" - - assert kwargs["layout"] == torch.strided, "layout value not supported!" - torch_stride = torch._prims_common.make_contiguous_strides_for(size) - - # get local tensor shape - local_shape = compute_local_shape(size, device_mesh, placements) - # initialize the local tensor - if init_op == torch.full: - fill_value = kwargs.pop("fill_value", 0) - local_tensor = init_op(local_shape, fill_value, **kwargs) - elif init_op == torch.rand or init_op == torch.randn: - # this tensor meta is not used except `shape` - dtype = kwargs.get("dtype", torch.get_default_dtype()) - - tensor_meta = TensorMeta(size, (0,), dtype) - spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) - - if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: - random._rng_tracker = random.OffsetBasedRNGTracker() - - assert random._rng_tracker is not None - with random._rng_tracker._distribute_region(spec): - local_tensor = init_op(local_shape, **kwargs) - else: - local_tensor = init_op(local_shape, **kwargs) - - spec = DTensorSpec( - device_mesh, - tuple(placements), - tensor_meta=TensorMeta( - size, - torch_stride, - local_tensor.dtype, - ), - ) - - return DTensor( - local_tensor, - spec, - requires_grad=kwargs["requires_grad"], - ) - - -def ones( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined - by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.ones, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def empty( # type: ignore[no-untyped-def] - *size, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` - is defined by the variable argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.empty, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def full( # type: ignore[no-untyped-def] - size, - fill_value, - *, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - requires_grad: bool = False, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with ``fill_value``. The scalar value type should match - ``device_mesh.device_type``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - fill_value(Scalar): the value to fill the output tensor with. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.full, - torch_size, - fill_value=fill_value, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def rand( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a uniform distribution - on the interval ``[0, 1)``. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.rand, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def randn( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with random numbers from a normal distribution - with mean 0 and variance 1. The shape of the tensor is defined by the variable - argument ``size``. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.randn, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) - - -def zeros( # type: ignore[no-untyped-def] - *size, - requires_grad: bool = False, - dtype: Optional[torch.dtype] = None, - layout: torch.layout = torch.strided, - device_mesh: Optional[DeviceMesh] = None, - placements: Optional[Sequence[Placement]] = None, -) -> DTensor: - """ - Returns a :class:`DTensor` filled with the scalar value 0. - - Args: - size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. - Can be a variable number of arguments or a collection like a list or tuple. - E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) - Keyword args: - requires_grad (bool, optional): If autograd should record operations on the - returned :class:`DTensor`. Default: ``False``. - dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. - Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). - layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. - Default: ``torch.strided``. - device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks - placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` - - Returns: - A :class:`DTensor` object on each rank - """ - torch_size = normalize_to_torch_size(size) - - return _dtensor_init_helper( - torch.zeros, - torch_size, - dtype=dtype, - layout=layout, - requires_grad=requires_grad, - device_mesh=device_mesh, - placements=placements, - ) +from torch.distributed.tensor._api import * # noqa: F401, F403 diff --git a/torch/distributed/_tensor/experimental/__init__.py b/torch/distributed/_tensor/experimental/__init__.py deleted file mode 100644 index 0b699eaf878d2c..00000000000000 --- a/torch/distributed/_tensor/experimental/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates -from contextlib import contextmanager - -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.experimental.func_map import local_map -from torch.distributed._tensor.experimental.register_sharding import register_sharding - - -__all__ = ["implicit_replication", "local_map", "register_sharding"] - - -@contextmanager -def implicit_replication(): - try: - DTensor._op_dispatcher._allow_implicit_replication = True - yield - finally: - DTensor._op_dispatcher._allow_implicit_replication = False diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 135dd8449a702a..6a4e70dbba4554 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -1,910 +1,10 @@ -# mypy: allow-untyped-defs -# Copyright (c) Meta Platforms, Inc. and affiliates +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases -from dataclasses import dataclass -from typing import Any, cast, List, NamedTuple, Optional, Tuple +TODO: throw warnings when this module imported +""" -import torch -import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._collective_utils import ( - fill_empty_tensor_to_shards, - mesh_broadcast, - mesh_scatter, - pad_tensor, - shard_dim_alltoall, - unpad_tensor, -) -from torch.distributed.device_mesh import DeviceMesh - - -__all__ = ["Placement", "Shard", "Replicate", "Partial", "DTensorSpec", "TensorMeta"] - - -class Placement: - # base class Placement type - - # convenient utils to check for placement types - def is_shard(self, dim: Optional[int] = None) -> bool: - is_shard_instance = isinstance(self, Shard) - if dim is not None and is_shard_instance: - return cast(Shard, self).dim == dim - else: - return is_shard_instance - - def is_replicate(self) -> bool: - return isinstance(self, Replicate) - - def is_partial(self) -> bool: - return isinstance(self, Partial) - - -@dataclass(frozen=True) -class Shard(Placement): - """ - The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension - ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the - DeviceMesh dimension only holds a shard/piece of the global Tensor. The - ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the - last few shards on the DeviceMesh dimension might be empty when the tensor dimension - is not evenly divisble on the DeviceMesh dimension. The ``Shard`` placement can be - used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) - - Args: - dim (int): The tensor dimension that describes the DTensor is sharded over its - corresponding DeviceMesh dimension. - - .. warning:: sharding on a tensor dimension where the tensor dimension size is not - evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. - """ - - dim: int - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - with_padding: bool = True, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor], List[int]]: - """ - This function uses torch.chunk to split a tensor into num_chunks shards along - the Shard placement dimension, and return a list of shards with their pad sizes. - - Keyword args: - with_padding (bool, optional): when True, we pad the tensor on the last - few ranks before calling the collectives (i.e. scatter/all_gather, etc.). - This is because collectives usually require equal size tensor inputs - """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - - # chunk tensor over dimension `dim` into n slices - tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) - num_empty_tensors = num_chunks - len(tensor_list) - - # if no need to have padding or tensor dim size is evenly sharded already - # we can return early. - if not with_padding or tensor.size(self.dim) % num_chunks == 0: - if contiguous: - tensor_list = [t.contiguous() for t in tensor_list] - return ( - fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), - [], - ) - - # compute the chunk size inline with ``torch.chunk`` to calculate padding - full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks - - # Compute chunk size for each chunk for ``self.dim`` - chunk_sizes = [ - tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 - for idx in range(num_chunks) - ] - # Compute pad size on each chunk - pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] - - # Reuse tensor to fill empty chunk with empty tensor - tensor_list = fill_empty_tensor_to_shards( - tensor_list, self.dim, num_empty_tensors - ) - shard_list = [] - for shard, pad_size in zip(tensor_list, pad_sizes): - # Fill the empty tensor with zeroes with padding. - if with_padding and pad_size > 0: - shard = pad_tensor(shard, self.dim, pad_size) - shard = shard.contiguous() if contiguous else shard - shard_list.append(shard) - return shard_list, pad_sizes - - @staticmethod - def _local_shard_size_on_dim( - size_on_dim: int, - num_chunks: int, - rank: int, - return_offset: bool = False, - ) -> Tuple[int, int]: - """ - returns the local shard size and offset on a given tensor dim - """ - # Compute the chunk size inline with ``torch.chunk`` - if size_on_dim % num_chunks == 0: - full_chunk_size = size_on_dim // num_chunks - return full_chunk_size, full_chunk_size * rank if return_offset else -1 - - # uneven sharding case - full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks - shard_starting_idx = full_chunk_size * rank - - if size_on_dim < shard_starting_idx: - return 0, size_on_dim if return_offset else -1 - else: - local_shard_size = ( - min(size_on_dim, shard_starting_idx + full_chunk_size) - - shard_starting_idx - ) - return local_shard_size, shard_starting_idx if return_offset else -1 - - def _shard_tensor( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - """ - shard and scatter a tensor on a mesh dimension (use coordinate - 0 on the mesh dimension as source of truth) - """ - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(mesh_dim=mesh_dim) - - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor - return tensor.new_empty(0, requires_grad=tensor.requires_grad) - - scatter_list, pad_sizes = self._split_tensor( - tensor, num_chunks, with_padding=True, contiguous=True - ) - - mesh_dim_local_rank = my_coordinate[mesh_dim] - output = torch.empty_like(scatter_list[mesh_dim_local_rank]) - mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) - - # Only unpad if the local_tensor was padded on the dimension. - if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0: - output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) - return output - - def _reduce_shard_tensor( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - reduce_op: str, - mesh_dim: int, - ) -> torch.Tensor: - """ - reduce and scatter a tensor on a mesh dimension - """ - my_coordinate = mesh.get_coordinate() - num_chunks = mesh.size(mesh_dim=mesh_dim) - - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return tensor - - is_padded = tensor.size(self.dim) % num_chunks != 0 - if is_padded: - scattered_list, pad_sizes = self._split_tensor( - tensor, num_chunks, with_padding=True, contiguous=True - ) - tensor = torch.cat(scattered_list, dim=self.dim) - elif not tensor.is_contiguous(): - tensor = tensor.contiguous() - - output = funcol.reduce_scatter_tensor( - tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) - ) - - if is_padded: - output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] - return output - - def _to_replicate_tensor( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - ) -> torch.Tensor: - """ - This function all_gather all shards and return a tensor that - is replicated on the previously sharded mesh dimension - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - # check if it's uneven, so we need to pad input tensor before all_gather - local_shape = list(local_tensor.size()) - - logical_dim_size = current_logical_shape[self.dim] - is_padded = logical_dim_size % num_chunks != 0 - - if is_padded: - full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks - pad_size = full_chunk_size - local_shape[self.dim] - local_tensor = pad_tensor(local_tensor, self.dim, pad_size) - - if not local_tensor.is_contiguous(): - local_tensor = local_tensor.contiguous() - - result = funcol.all_gather_tensor( - local_tensor, - gather_dim=self.dim, - group=(mesh, mesh_dim), - ) - if is_padded: - unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] - result = unpad_tensor(result, self.dim, unpad_size) - return result - - def _replicate_to_shard( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_index: int, - ) -> torch.Tensor: - """ - transform from replicated tensor to a sharded tensor on - the current rank, which would perform a local chunk - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - shards, _ = self._split_tensor( - local_tensor, - num_chunks, - with_padding=False, - contiguous=False, - ) - return shards[shard_index].clone() - - def _to_new_shard_dim( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - new_shard_dim: int, - ) -> torch.Tensor: - """ - transform from existing sharded tensor to a new sharded tensor on - that shard on a new dimension, which performs an alltoall - """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return local_tensor, - # which should be an empty tensor - return local_tensor - - num_chunks = mesh.size(mesh_dim=mesh_dim) - - old_dim_logical_size = current_logical_shape[self.dim] - new_dim_logical_size = current_logical_shape[new_shard_dim] - old_dim_padding = old_dim_logical_size % num_chunks != 0 - new_dim_padding = new_dim_logical_size % num_chunks != 0 - if old_dim_padding: - old_dim_full_chunk_size = ( - old_dim_logical_size + num_chunks - 1 - ) // num_chunks - old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) - local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) - if new_dim_padding: - new_dim_full_chunk_size = ( - new_dim_logical_size + num_chunks - 1 - ) // num_chunks - new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( - new_shard_dim - ) - local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) - - if not local_tensor.is_contiguous(): - local_tensor = local_tensor.contiguous() - - new_tensor = shard_dim_alltoall( - local_tensor, self.dim, new_shard_dim, mesh, mesh_dim - ) - - if old_dim_padding: - old_dim_unpad_size = ( - old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] - ) - new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] - - if new_dim_padding: - local_shard_size_on_new_dim = self._local_shard_size_on_dim( - new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] - )[0] - new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] - new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] - - return new_tensor - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Shard): - return False - return self.dim == other.dim - - def __hash__(self) -> int: - return hash(self.dim) - - def __repr__(self) -> str: - """ - machine readable representation of the Shard placement - """ - return f"Shard(dim={self.dim})" - - def __str__(self) -> str: - """human readable representation of the Shard placement""" - return f"S({self.dim})" - - -# kw_only is only available in python >= 3.10 -kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} - - -@dataclass(frozen=True, **kw_only_dataclass) -class _StridedShard(Shard): - """ - _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor - is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. - We call this right-to-left sharding which is the opposite of the default - left-to-right sharding. See the example below: - tensor shape: [8, 8] - mesh: [[0, 1], [2, 3]], names=("dp", "tp") - placements: [Shard(0), Shard(0)] - - The default sharding behavior shards the tensor on "dp" mesh dimension first then - "tp" dimension. The sharding result will be: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 1 (row 2-3) - 2 | (1, 0) | 2 (row 4-5) - 3 | (1, 1) | 3 (row 6-7) - - While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on - "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the - result: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 2 (row 4-5) - 2 | (1, 0) | 1 (row 2-3) - 3 | (1, 1) | 3 (row 6-7) - - The consequence is, any attempt to redistribute this DTensor to a full replica will - produce a wrong result because the shard-to-replicate redistribution always happens - right-to-left, regardless it's left-to-right sharding or right-to-left. To address - this, we use _StridedShard placement to make this right-to-left sharding compatible - with our left-to-right convention on both tensor distribution and redistribution. - - Now with _StridedShard, the right-to-left sharding above can be represented as: - tensor shape: [8, 8] - mesh: [[0, 1], [2, 3]], names=("dp", "tp") - placements: [_StridedShard(0, split_factor=2), Shard(0)] - - And a left-to-right processing of `placements` will produce the same result, which is - different from using the `Shard` placement: - Rank | Mesh Coordinate | Shard Index - ------------------------------------------------ - 0 | (0, 0) | 0 (row 0-1) - 1 | (0, 1) | 2 (row 4-5) - 2 | (1, 0) | 1 (row 2-3) - 3 | (1, 1) | 3 (row 6-7) - - The argument `split_factor` is the number of existing shards over the tensor sharding - dimension before processing the _StridedShard placement, as if the sharding happened - right-to-left. In the example above, the tensor should first be sharded on the "tp" - dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the - `split_factor` of the _StridedShard placement on "dp" dim is 2. - - TODO: strided sharding needs to work fine with uneven sharding. Now it forbids - resharding if the tensor is unevenly sharded. - TODO: we should remove _StridedShard placement once we can unify it with Shard - """ - - split_factor: int - - def __eq__(self, other: object) -> bool: - if isinstance(other, _StridedShard): - return self.dim == other.dim and self.split_factor == other.split_factor - elif isinstance(other, Shard): - # TODO: this is to avoid extra all-gather in dtensor op dispatch - # note that sharding prop would not produce _StridedShard and an - # placement inequality would introduce an all-gather for resharding - return self.dim == other.dim - return False - - def __hash__(self) -> int: - return hash((self.dim, self.split_factor)) - - def __repr__(self) -> str: - """ - machine readable representation of the _StridedShard placement - """ - return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" - - def __str__(self) -> str: - """human readable representation of the _StridedShard placement""" - return f"_S({self.dim}, {self.split_factor})" - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - with_padding: bool = True, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor], List[int]]: - """ - TODO: currently _StridedShard does not support padding - """ - assert ( - self.dim <= tensor.ndim - ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - - total_split = num_chunks * self.split_factor - assert tensor.size(self.dim) % total_split == 0, ( - "_StridedShard currently only allows even sharding but got tensor size" - f" {tensor.size(self.dim)} on dim {self.dim} and total split" - f" {total_split}={num_chunks} * {self.split_factor}" - ) - - group_size = self.split_factor - total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) - tensor_list = [ - torch.cat( - [ - total_split_tensor_list[i + j * num_chunks] # stride is num_chunks - for j in range(group_size) - ], - dim=self.dim, - ) - for i in range(num_chunks) - ] - - if contiguous: - tensor_list = [t.contiguous() for t in tensor_list] - - return tensor_list, [] - - def _to_replicate_tensor( - self, - local_tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - current_logical_shape: List[int], - ) -> torch.Tensor: - """ - Note: currently _StridedShard does not support padding - """ - num_chunks = mesh.size(mesh_dim=mesh_dim) - total_split = num_chunks * self.split_factor - # NOTE: we require Strided Sharding to be even for now - assert current_logical_shape[self.dim] % total_split == 0, ( - "_StridedShard requires even sharding but got tensor size " - f"{current_logical_shape[self.dim]} on dim {self.dim} and " - f"total split {total_split}=num_chunks {num_chunks} " - f"* split_factor {self.split_factor}" - ) - - result = funcol.all_gather_tensor( - local_tensor, - gather_dim=self.dim, - group=(mesh, mesh_dim), - ) - if isinstance(result, funcol.AsyncCollectiveTensor): - result = result.wait() - - tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) - # rearrange the order - new_tensor_shard_list = [] - for idx in range(len(tensor_shard_list)): - # the shard split of index `idx` is assigned a new index within - # _StridedShard._split_tensor: - # the original tensor was split into `total_split` chunks, - # all chunks with the same `idx % num_chunks` are merged into one - # new shard and placed on mesh's local rank `idx % num_chunks` - idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks - new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) - - return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() - - -@dataclass(frozen=True) -class Replicate(Placement): - """ - The ``Replicate()`` placement describes the DTensor replicating on a corresponding - ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a - replica of the global Tensor. The ``Replicate`` placement can be used by all - DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) - """ - - def __eq__(self, other: object) -> bool: - return isinstance(other, Replicate) - - def __hash__(self) -> int: - # every replicate placement is the same - return -1 - - def __repr__(self) -> str: - """ - machine readable representation of the Replicate placement - """ - return "Replicate()" - - def __str__(self) -> str: - """ - human readable representation of the Replicate placement - """ - return "R" - - def _replicate_tensor( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - """ - Replicate (broadcast) a torch.Tensor on a mesh dimension (use - the first coordinate on the mesh dimension as source of truth) - """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - # if rank is not part of mesh, we simply return an empty tensor - return tensor.new_empty(0, requires_grad=tensor.requires_grad) - - tensor = tensor.contiguous() - mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) - return tensor - - -@dataclass(frozen=True) -class Partial(Placement): - """ - The ``Partial(reduce_op)`` placement describes the DTensor that is pending - reduction on a specified ``DeviceMesh`` dimension, where each rank on the - DeviceMesh dimension holds the partial value of the global Tensor. User can - redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` - placement on the specified ``DeviceMesh`` dimension using ``redistribute``, - which would trigger necessary communication operations under the hood (i.e. - ``allreduce``, ``reduce_scatter``). - - Args: - reduce_op (str, optional): The reduction op to be used for the partial DTensor - to produce Replicated/Sharded DTensor. Only element-wise reduction operations - are supported, including: "sum", "avg", "product", "max", "min", default: "sum". - - .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, - and can only be used by the ``DTensor.from_local`` API. - """ - - reduce_op: str = "sum" - - def _reduce_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # Partial placement contract #1: - # _reduce_value: reduce the value of the tensor on the mesh dimension - return funcol.all_reduce( - tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) - ) - - def _reduce_shard_value( - self, - tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int, - shard_spec: Placement, - ) -> torch.Tensor: - # Partial placement contract #2: - # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension - shard_spec = cast(Shard, shard_spec) - return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) - - def _partition_value( - self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int - ) -> torch.Tensor: - # Partial placement contract #3: - # _partition_value: partition the value of a replicated tensor on the mesh dimension - - # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a divison operation - # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation - # TODO: if the reduce_op is min/max, etc. the _partition_value should be a - # different operation - assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" - num_chunks = mesh.size(mesh_dim=mesh_dim) - return tensor / num_chunks - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Partial): - return False - return self.reduce_op == other.reduce_op - - def __hash__(self) -> int: - return 1 + hash(self.reduce_op) - - def __repr__(self) -> str: - """ - machine readable representation of the Partial placement - """ - return f"Partial({self.reduce_op})" - - def __str__(self) -> str: - """ - human readable representation of the Partial placement - """ - return "P" - - -# We keep the old _Partial name for a while for BC reason -_Partial = Partial - - -class TensorMeta(NamedTuple): - # simple named tuple to represent tensor metadata - # intentionally to stay simple only for sharding - # propagation purposes. - shape: torch.Size - stride: Tuple[int, ...] - dtype: torch.dtype - - -# used internally to propagate the placements -@dataclass -class DTensorSpec: - mesh: DeviceMesh - placements: Tuple[Placement, ...] - - # tensor meta will only be set during sharding propagation - tensor_meta: Optional[TensorMeta] = None - - def __post_init__(self): - if not isinstance(self.placements, tuple): - self.placements = tuple(self.placements) - self._hash: Optional[int] = None - - def __setattr__(self, attr: str, value: Any): - super().__setattr__(attr, value) - # Make sure to recompute the hash in case any of the hashed attributes - # change (though we do not expect `mesh` or `placements` to change) - if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): - self._hash = None - - def _hash_impl(self) -> int: - # hashing and equality check for DTensorSpec are used to cache the sharding - # propagation results. We only need to consider the mesh, placements, shape - # dtype and stride. - # Caveat: we need to keep this in mind and sync hash and eq if we add more - # fields to them. - if self.tensor_meta is not None: - return hash( - ( - self.mesh, - self.placements, - self.tensor_meta.shape, - self.tensor_meta.stride, - self.tensor_meta.dtype, - ) - ) - return hash((self.mesh, self.placements)) - - def __hash__(self) -> int: - # We lazily cache the spec to avoid recomputing the hash upon each - # use, where we make sure to update the hash when the `tensor_meta` - # changes by overriding `__setattr__`. This must be lazy so that Dynamo - # does not try to hash non-singleton `SymInt`s for the stride. - if self._hash is None: - self._hash = self._hash_impl() - return self._hash - - def __eq__(self, __o: object) -> bool: - if not ( - isinstance(__o, DTensorSpec) - and self.mesh == __o.mesh - and self.placements == __o.placements - ): - return False - if self.tensor_meta is None or __o.tensor_meta is None: - return self.tensor_meta == __o.tensor_meta - - return ( - self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr] - and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr] - and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr] - ) - - def __str__(self) -> str: - """ - human readable representation of the DTensorSpec - """ - if len(self.placements) == 1: - placement_str = str(self.placements[0]) - else: - placement_str = str(self.placements) - - if self.tensor_meta is not None: - tensor_shape = str(tuple(self.tensor_meta.shape)) - else: - tensor_shape = "unknown shape" - - return f"Spec({placement_str} on {tensor_shape})" - - @property - def shape(self) -> torch.Size: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return self.tensor_meta.shape - - @property - def stride(self) -> Tuple[int, ...]: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return self.tensor_meta.stride - - @property - def ndim(self) -> int: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return len(self.tensor_meta.shape) - - @property - def num_shards(self) -> int: - num_shards = 1 - for i, placement in enumerate(self.placements): - if placement.is_shard(): - num_shards *= self.mesh.size(i) - return num_shards - - @property - def device_mesh(self) -> DeviceMesh: - # simple aliasing for the mesh field, make some - # checks that mixes DTensor/DTensorSpec easier - return self.mesh - - @property - def dim_map(self) -> List[int]: - """ - dim_map is a property we derive from `placements` of - the distributed tensor. It simply return a list of ints - where dim_map[i] denotes the sharding mapping to the mesh - dimension, and len(dim_map) == dist_tensor.ndim - dim_map[i] = -1: means tensor dim i replicate on mesh - dim_map[i] = j: means tensor dim i shard on mesh dim j - - For example, we have a dist tensor that have the shape of - [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: - [Shard(1)], the dim_map of this placement would be: - [-1, 0, -1]. This representation is pretty helpful during - sharding propagation where we could know exactly each - tensor dimension is sharded or not. - - Note that if placements contains `_Partial`, we have to - explicitly deal with it, so that when we create a DTensorSpec - with dim_map, we could properly record the pending sums. - """ - # dims mapping of dist tensor sharding - # return size of tensor ndim, -1 represent replicate - # and int >=0 represent shard on that device mesh dim - r = [-1] * self.ndim - for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - if r[shard_dim] > -1: - raise ValueError( - f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," - " DTensor operator implementation does not support things like hybrid" - " sharding strategies yet (i.e. [Shard(0), Shard(0)])" - ) - r[shard_dim] = i - return r - - @property - def num_shards_map(self) -> List[int]: - """ - dim_map is a property we derive from `placements` of - the distributed tensor. Unlike `dim_map`, `num_shards_map` - denotes how many shards each tensor dim has. Like `dim_map`: - len(num_shards_map) == dist_tensor.ndim - num_shards_map[i] = 1: means tensor dim i is not sharded - num_shards_map[i] = j: means tensor dim i has j shards in total - - For example, we have a dist tensor of shape [18, 20, 30], - a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements - ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor - would be: [4, 2, 1]. - """ - r = [1] * self.ndim - for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - r[shard_dim] *= self.mesh.size(i) - - return r - - @property - def sums(self) -> List[int]: - """ - sums is a property we derive from `placements` of the - distributed tensor. It simply return a list of ints where - sums[i] denotes the pending sum (partial) on mesh dim i - """ - return [ - idx - for idx, placement in enumerate(self.placements) - if placement.is_partial() - ] - - @classmethod - def from_dim_map( - cls, - mesh: DeviceMesh, - dim_map: List[int], - sums: List[int], - tensor_meta: Optional[TensorMeta] = None, - ) -> "DTensorSpec": - """ - Construct a DTensorSpec from dim_map list and pending sum. - - Args: - mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec - dim_map (List[int]): a list of integer that represents sharding on each - tensor dimension, see `dim_map` property doc for details - sums (List[int]): a list of integer that represents the dist tensor have - pending sum on which device mesh dimension. - tensor meta (TensorMeta): DTensor metadata - - Return: - a class:`DTensorSpec` object - """ - # by default replicate on device mesh dims - placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] - - # find all mesh dims that need pending reductions - for s in sums: - placements[s] = Partial() - - for i, m in enumerate(dim_map): - if m >= 0: - placement = placements[m] - if placement.is_shard(): - placement = cast(Shard, placement) - raise RuntimeError( - f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" - ) - elif placement.is_partial(): - raise RuntimeError( - f"DeviceMesh dimension {m} cannot be both shard and partial!" - ) - placements[m] = Shard(i) - - return cls(mesh, tuple(placements), tensor_meta=tensor_meta) - - def is_replicated(self): - """ - return True if the current DTensorSpec replicates on all mesh dims (devices) - """ - return all(placement.is_replicate() for placement in self.placements) - - def is_sharded(self): - """ - return True if the current DTensorSpec is sharded on any mesh dims (devices) - """ - return any(placement.is_shard() for placement in self.placements) - - def shallow_copy_with_tensor_meta( - self, tensor_meta: Optional[TensorMeta] - ) -> "DTensorSpec": - """ - Shallow copy the DTensorSpec with a new tensor_meta. - """ - assert tensor_meta is not None, "shallow copy with no tensor_meta!" - return DTensorSpec( - self.mesh, - self.placements, - tensor_meta=tensor_meta, - ) +from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from torch.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/torch/distributed/_tools/__init__.py b/torch/distributed/_tools/__init__.py index 8acd9cce2f9033..cd57eedba37517 100644 --- a/torch/distributed/_tools/__init__.py +++ b/torch/distributed/_tools/__init__.py @@ -2,3 +2,4 @@ from .mem_tracker import MemTracker from .memory_tracker import MemoryTracker from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py new file mode 100644 index 00000000000000..87f4d3f36b60e0 --- /dev/null +++ b/torch/distributed/_tools/runtime_estimator.py @@ -0,0 +1,527 @@ +# Owner(s): ["module: unknown"] +import math +import os +from collections import defaultdict +from typing import Any, Callable, Dict, List, Set, Tuple +from typing_extensions import Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import active_fake_mode +from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.flop_counter import flop_registry + + +aten = torch.ops.aten + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = { + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, +} +# We can ignore benchmarking tensor create ops +_CREATE_OPS = { + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, +} + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it agggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _float_types: Set[torch.dtype] = { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + } + _no_fallback_kernel: Set[torch._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: List[str] = [] + self.mod_bw_pre_order: List[str] = [] + self.mod_fw_post_order: List[str] = [] + self.mod_bw_post_order: List[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in cls._float_types: + out = torch.rand_like(e, device=e.fake_device) + else: + out = torch.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record(torch.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(torch.cuda.current_stream()) + torch.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, torch.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance( + cls.fake_mode, FakeTensorMode + ), "Initialize/Assign FakeTensorMode before using this function" + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert ( + torch.cuda.is_available() + ), "Roofline estimation needs to access CUDA capabilities to make estimations" + + def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = ( + math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + ) + return mem_consumed + + def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert ( + len(out_dtypes) == 1 + ), f"Only support single out dtype got {out_dtypes} for {func_packet}" + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) + for t in flat_args_kwargs + if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance( + fake_mode, FakeTensorMode + ), "No FakeTensorMode found, designed to used under FakeTensorMode" + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 0acafd6868d3bb..872ad0e2a76731 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -136,7 +136,7 @@ def _low_precision_hook( prec: torch.dtype, state: LowPrecisionState, grad: torch.Tensor, - output: torch.Tensor, + output: Optional[torch.Tensor], ): if grad.dtype != prec: grad.data = grad.data.to(prec) diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py index 3347ea8bc432ae..b846a1d47d8434 100644 --- a/torch/distributed/checkpoint/_nested_dict.py +++ b/torch/distributed/checkpoint/_nested_dict.py @@ -3,7 +3,14 @@ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE -from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict +from . import _version +from ._traverse import ( + OBJ_PATH, + set_element, + STATE_DICT_ITEM, + traverse_state_dict, + traverse_state_dict_v_2_3, +) """ @@ -40,7 +47,16 @@ def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: flattened[new_fqn] = value mappings[new_fqn] = path - traverse_state_dict(state_dict, flat_copy) + # We started to flatten dictionary since v2.4. But in order to not break + # the checkpoints that were saved before v2.4, we need to keep the old + # traversal so that we can reconstruct those checkpoints. + use_v_2_3 = ( + _version._derived_version is not None and _version._derived_version == "2_3" + ) + if use_v_2_3: + traverse_state_dict_v_2_3(state_dict, flat_copy) + else: + traverse_state_dict(state_dict, flat_copy) return flattened, mappings diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 8bcb832c71980f..2c79013abeb269 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -14,8 +14,8 @@ import torch from torch.distributed._shard.sharded_tensor.api import ShardedTensor -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.distributed.tensor import DTensor PATH_ITEM = Union[str, int] @@ -40,14 +40,11 @@ def traverse_state_dict( ) -> None: """ Invoke ``visitor`` for each value recursively in ``state_dict``. - Mapping, list, and tuple will be flattened and other value types are treated - as the terminal values and will invoke ``visitor``. - Mapping is treated as non terminal node and will be flattened. - List and tuple, on the other hand, will not be flattened unless containing other - mapping containers or tensors. + Mapping will be traversed and ``visitor`` will be applied to the leaf elements. + ``visitor`` will only be applied to elements in a list or a tuple, if the + container contains tensors or mappings. """ - # a value is terminal if it has no other containers values inside it def _is_terminal(value: STATE_DICT_ITEM) -> bool: values: Collection[STATE_DICT_ITEM] if isinstance(value, Mapping): @@ -78,6 +75,49 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: _traverse_obj((str(key),), value) +def traverse_state_dict_v_2_3( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + def set_element( root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM ) -> None: diff --git a/torch/distributed/checkpoint/_version.py b/torch/distributed/checkpoint/_version.py new file mode 100644 index 00000000000000..b3065bdfd6a2c1 --- /dev/null +++ b/torch/distributed/checkpoint/_version.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Optional + + +_derived_version: Optional[str] = None diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 462808217ede63..67f3c73bca51da 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -11,7 +11,6 @@ import torch from torch.distributed._shard._utils import narrow_tensor_by_index -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans from torch.distributed.checkpoint._nested_dict import ( FLATTEN_MAPPING, @@ -45,6 +44,9 @@ _init_state_dict, ) from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.tensor import DTensor + +from . import _version logger: logging.Logger = logging.getLogger(__name__) @@ -195,6 +197,39 @@ def set_up_planner( def create_local_plan(self) -> LoadPlan: assert self.metadata is not None + if self.flatten_state_dict: + # To support checkpoints that are saved before v2.4, we have to + # differentiate if the missing keys are due to old checkpoints. + # The contracts are: + # 1. There are 3 cases when we found a missing key. + # 1.1 Actual missing key, but allow_partial_load is False + # 1.2 Actual missing key, but allow_partial load is True + # 1.3 Old checkpoint, but allow_partial_load is False + # 1.4 Old checkpoint, but allow_partial_load is True + # 2. If we found a missing key, we first convert the keys back to + # the key format of v2.3 + # 3. If the previous missing keys are in the v2.3 keys, we assume + # this is a old checkpoint. + # 4. Pass the state_dict to `create_default_local_load_plan()`, + # which has the logic to check missing for allow_partial_load. + # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to + # `create_default_local_load_plan()`. The logic here is to determine + # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). + current_keys = set(self.state_dict.keys()) + load_keys = set(self.metadata.state_dict_metadata.keys()) + missing_keys = load_keys - current_keys + if missing_keys: + _version._derived_version = "2_3" + old_state_dict, old_mappings = flatten_state_dict( + self.original_state_dict + ) + old_keys = set(old_state_dict.keys()) + if old_keys & missing_keys: + self.state_dict, self.mappings = old_state_dict, old_mappings + # _derived_version is only used by flatten_state_dict now. + # Set it back to None so that later we can save to a new version. + _version._derived_version = None + return create_default_local_load_plan( self.state_dict, self.metadata, not self.allow_partial_load ) diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index 589f9b93544289..0f8e392f4e9c36 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -11,12 +11,12 @@ import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor.device_mesh import init_device_mesh DEVICE = "cuda" diff --git a/torch/distributed/checkpoint/examples/stateful_example.py b/torch/distributed/checkpoint/examples/stateful_example.py index 9624229b53f79b..1e5e2f7c967b63 100644 --- a/torch/distributed/checkpoint/examples/stateful_example.py +++ b/torch/distributed/checkpoint/examples/stateful_example.py @@ -12,11 +12,11 @@ import torch.distributed.checkpoint as dcp import torch.multiprocessing as mp import torch.nn as nn -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.checkpoint.state_dict import ( _patch_model_state_dict, _patch_optimizer_state_dict, ) +from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 3faee96b238dc7..079608c7230629 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -747,15 +747,19 @@ def __init__( N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. """ - super().__init__( + _FileSystemWriter.__init__( + self, path=path, single_file_per_rank=single_file_per_rank, sync_files=sync_files, thread_count=thread_count, per_thread_copy_ahead=per_thread_copy_ahead, - cache_staged_state_dict=cache_staged_state_dict, overwrite=overwrite, ) + BlockingAsyncStager.__init__( + self, + cache_staged_state_dict=cache_staged_state_dict, + ) def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE: """Override of AsyncStager.stage""" diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index f37fae5c0a87e7..5ad170768db929 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -12,7 +12,6 @@ ) from torch.distributed._shard.sharded_tensor.shard import Shard from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec -from torch.distributed._tensor import DTensor from torch.distributed.checkpoint._nested_dict import unflatten_state_dict from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner from torch.distributed.checkpoint.metadata import ( @@ -39,6 +38,7 @@ from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DTensor STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]] diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index 17ab4b027d3a37..d715c651e024ba 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -7,8 +7,8 @@ from torch._utils import _get_device_module from torch.distributed._shard.metadata import ShardMetadata from torch.distributed._shard.sharded_tensor import ShardedTensor -from torch.distributed._tensor import DTensor -from torch.distributed._tensor._utils import compute_local_shape_and_global_offset +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset from .metadata import ( BytesStorageMetadata, diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index f7a7ea799ef03e..ef90b32eb2042f 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -32,7 +32,6 @@ _offload_state_dict_to_cpu, _unflatten_state_dict, ) -from torch.distributed._tensor import DTensor from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_PREFIX, ) @@ -50,6 +49,7 @@ _get_module_fsdp_state_if_fully_sharded_module, FSDP_WRAPPED_MODULE, ) +from torch.distributed.tensor import DTensor from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils._pytree import tree_map_only @@ -552,6 +552,7 @@ def _load_model_state_dict( state_dict[fqn_with_prefix] = state_dict.pop(fqn) local_state_dict[fqn_with_prefix] = value + assign = False if info.broadcast_from_rank0 or info.full_state_dict: device = None for key, value in local_state_dict.items(): @@ -563,7 +564,7 @@ def _load_model_state_dict( assert device is not None if device == torch.device("meta"): device = dist.distributed_c10d._get_pg_default_device() - model.to_empty(device=device) + assign = True if info.broadcast_from_rank0: _broadcast_state_dict( state_dict, local_state_dict, device=device, strict=info.strict @@ -577,7 +578,7 @@ def _load_model_state_dict( return cast( _IncompatibleKeys, _state_dict_fn(model, "load_state_dict")( - state_dict=state_dict, strict=info.strict + state_dict=state_dict, strict=info.strict, assign=assign ), ) @@ -590,17 +591,17 @@ def _init_optim_state(optim: torch.optim.Optimizer) -> None: # The optimizer state is initialized. return + # There are some stateless optimizers like SGD. These optimizer will + # not return in the above condition. So if gradients exist, we should also + # return. If gradients do not exist, the following initialization should + # not disturb SGD because the gradients and lr are both zero. for param_group in optim.param_groups: for param in param_group[_PARAMS]: if param.grad is not None: - raise RuntimeError( - "state_dict can only be used if the optimizer " - "states are initialized (usually after one step() with " - "gradients) or gradients are None. For the later case, " - "state_dict will fake the gradients as zero " - "to initialize the optimizer states. However, the " - "gradients are not None." - ) + return + + for param_group in optim.param_groups: + for param in param_group[_PARAMS]: if param.requires_grad: param.grad = torch.zeros_like(param) diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 5bc171a00dd2d9..cb90dd11912159 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -44,7 +44,7 @@ def _all_gather_keys( ) -> List[Any]: """Gathers all keys, and returns them sorted.""" keys = list(local_dict.keys()) - gathered_keys: List[List[Any]] = [None] * dist.get_world_size() # type: ignore[list-item] + gathered_keys: List[List[Any]] = [None] * dist.get_world_size(group) # type: ignore[list-item] dist.all_gather_object(gathered_keys, keys, group=group) return sorted(set(itertools.chain.from_iterable(gathered_keys))) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 14aa38be42b34c..d08bcfefc50e5f 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -36,6 +36,7 @@ def _init_device_mesh_stub(): else: + from torch._C._distributed_c10d import Backend as C10dBackend from torch.distributed.distributed_c10d import ( _find_pg_by_ranks_and_tag, _get_default_group, @@ -66,7 +67,7 @@ def __init__(self) -> None: self.mesh_stack: List[DeviceMesh] = [] self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} self.mesh_dim_group_options: Dict[ - int, Tuple[str, Optional[ProcessGroup.Options]] + int, Tuple[str, Optional[C10dBackend.Options]] ] = {} self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. @@ -279,7 +280,7 @@ def _set_mesh_dim_group_options( self, dim: int, backend: str, - pg_options: Optional[ProcessGroup.Options] = None, + pg_options: Optional[C10dBackend.Options] = None, ) -> None: self.mesh_dim_group_options[dim] = (backend, pg_options) @@ -335,6 +336,35 @@ def _get_slice_mesh_dims( return slice_mesh_dims + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> List["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. + """ + mesh_dim = self.get_mesh_dim_by_name(device_mesh, mesh_dim_name) + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape( + -1, device_mesh.mesh.size(mesh_dim) + ) + + cur_rank = device_mesh.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_infos = ( + [device_mesh._dim_group_infos[mesh_dim]] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + _mesh_resources: _MeshEnv = _MeshEnv() def _get_device_handle(device_type: str = "cuda"): @@ -440,7 +470,7 @@ def _get_or_create_default_group(self): world_size = get_world_size() if self.mesh.numel() > world_size: raise RuntimeError( - f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!" + f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!" ) device_handle = _get_device_handle(self.device_type) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9fa3224873c9fc..2d9357bbd15a44 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -280,6 +280,7 @@ class Backend(str): NCCL: ProcessGroup.BackendType.NCCL, XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, } def __new__(cls, name: str): @@ -1554,7 +1555,7 @@ def init_process_group( backend, store, group_name, - pg_options=pg_options, + backend_options=pg_options, timeout=timeout, device_id=device_id, group_desc="default_pg", @@ -1651,7 +1652,7 @@ def _new_process_group_helper( backend, store, group_name, - pg_options=None, + backend_options=None, timeout=None, pg_tag=None, device_id=None, @@ -1730,11 +1731,17 @@ def _new_process_group_helper( return GroupMember.NON_GROUP_MEMBER, None prefix_store = PrefixStore(f"{group_name}/", store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, group_size, base_pg_options + prefix_store, + group_rank, + group_size, ) + # Set the default backend when only single backend is passed in. + if "," not in str(backend) and ":" not in str(backend): + assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" + pg._set_default_backend(Backend.backend_type_map[backend]) if device_id: pg.bound_device_id = device_id backend_config = BackendConfig(backend) @@ -1761,8 +1768,8 @@ def _new_process_group_helper( backend_prefix_store, backend_class.rank(), backend_class.size(), - base_pg_options, ) + pg._set_default_backend(backend_type) elif backend_str == Backend.GLOO: # TODO: remove this check after lazy initialization is supported # if pg_options is not None: @@ -1774,27 +1781,30 @@ def _new_process_group_helper( elif backend_str == Backend.NCCL: if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") - if pg_options is not None: + if backend_options is not None: assert isinstance( - pg_options, ProcessGroupNCCL.Options - ), "Expected pg_options argument to be of type ProcessGroupNCCL.Options" - if pg_options._timeout != timeout: + backend_options, ProcessGroupNCCL.Options + ), "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + if backend_options._timeout != timeout: warnings.warn( - "pg_options._timeout was specified, " + "backend_options._timeout was specified, " "but timeout kwarg has a default value that will always override it. " ) else: - # default pg_options for NCCL - pg_options = ProcessGroupNCCL.Options() - pg_options.is_high_priority_stream = False - pg_options._timeout = timeout + # default backend_options for NCCL + backend_options = ProcessGroupNCCL.Options() + backend_options.is_high_priority_stream = False + backend_options._timeout = timeout + if split_from: - pg_options.split_from = split_from - pg_options.split_color = _process_group_color(global_ranks_in_group) - pg_options.global_ranks_in_group = global_ranks_in_group - pg_options.group_name = group_name + backend_options.split_from = split_from + backend_options.split_color = _process_group_color( + global_ranks_in_group + ) + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name backend_class = ProcessGroupNCCL( - backend_prefix_store, group_rank, group_size, pg_options + backend_prefix_store, group_rank, group_size, backend_options ) backend_type = ProcessGroup.BackendType.NCCL elif backend_str == Backend.UCC and is_ucc_available(): @@ -1840,7 +1850,7 @@ def _new_process_group_helper( dist_backend_opts.group_id = group_name dist_backend_opts.global_ranks_in_group = global_ranks_in_group - backend_class = creator_fn(dist_backend_opts, pg_options) + backend_class = creator_fn(dist_backend_opts, backend_options) # Set sequence numbers for gloo and nccl backends. if backend_str == Backend.GLOO: @@ -4475,12 +4485,15 @@ def split_group( global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] prefix_store = PrefixStore(f"{group_name}/", default_store) - base_pg_options = ProcessGroup.Options(backend=str(backend)) - base_pg_options._timeout = timeout + # We register the backend after initializing and timeout is set in pg_options. pg: ProcessGroup = ProcessGroup( - prefix_store, group_rank, len(my_group), base_pg_options + prefix_store, + group_rank, + len(my_group), ) + backend_type = ProcessGroup.BackendType.NCCL pg.bound_device_id = device_id + pg._set_default_backend(backend_type) pg_options._timeout = timeout pg_options.split_from = parent_backend @@ -4490,7 +4503,6 @@ def split_group( backend_class = ProcessGroupNCCL( prefix_store, group_rank, len(my_group), pg_options ) - backend_type = ProcessGroup.BackendType.NCCL backend_class._set_sequence_number_for_group() pg._register_backend(torch.device("cuda"), backend_type, backend_class) @@ -4613,7 +4625,7 @@ def _new_group_with_tag( ranks=None, timeout=None, backend=None, - pg_options=None, + backend_options=None, pg_tag=None, use_local_synchronization=False, group_desc=None, @@ -4688,7 +4700,7 @@ def _new_group_with_tag( backend, default_store, group_name, - pg_options=pg_options, + backend_options=backend_options, timeout=timeout, pg_tag=pg_tag, device_id=device_id, diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 0b3a9a2ce29087..5befbffe49573d 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -16,6 +16,7 @@ import subprocess import sys import tempfile +import threading import time from abc import ABC, abstractmethod from contextlib import nullcontext @@ -470,11 +471,17 @@ def __init__( def start(self) -> None: """Start processes using parameters defined in the constructor.""" - signal.signal(signal.SIGTERM, _terminate_process_handler) - signal.signal(signal.SIGINT, _terminate_process_handler) - if not IS_WINDOWS: - signal.signal(signal.SIGHUP, _terminate_process_handler) - signal.signal(signal.SIGQUIT, _terminate_process_handler) + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGTERM, _terminate_process_handler) + signal.signal(signal.SIGINT, _terminate_process_handler) + if not IS_WINDOWS: + signal.signal(signal.SIGHUP, _terminate_process_handler) + signal.signal(signal.SIGQUIT, _terminate_process_handler) + else: + logger.warning( + "Failed to register signal handlers since torchelastic is running on a child thread. " + "This could lead to orphaned worker processes if the torchrun is terminated." + ) self._start() self._stdout_tail.start() self._stderr_tail.start() diff --git a/torch/distributed/elastic/rendezvous/__init__.py b/torch/distributed/elastic/rendezvous/__init__.py index 62a31adab27b01..22ec0c9a0f6758 100644 --- a/torch/distributed/elastic/rendezvous/__init__.py +++ b/torch/distributed/elastic/rendezvous/__init__.py @@ -143,10 +143,11 @@ class that implements the rendezvous mechanism described above. It is a backend- RendezvousStoreInfo, RendezvousTimeoutError, ) -from .registry import _register_default_handlers +from .registry import _register_default_handlers, _register_out_of_tree_handlers _register_default_handlers() +_register_out_of_tree_handlers() __all__ = [ diff --git a/torch/distributed/elastic/rendezvous/api.py b/torch/distributed/elastic/rendezvous/api.py index 9cde6758981abf..cbfc5532c76a6a 100644 --- a/torch/distributed/elastic/rendezvous/api.py +++ b/torch/distributed/elastic/rendezvous/api.py @@ -68,14 +68,21 @@ class RendezvousStoreInfo: master_port: int @staticmethod - def build(rank: int, store: Store) -> "RendezvousStoreInfo": + def build( + rank: int, store: Store, local_addr: Optional[str] + ) -> "RendezvousStoreInfo": """Factory method, finds unused new port on rank0 host and addr/port info with all ranks. If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor. + + Args: + rank: rank of the current node + store: store to use for rendezvous + local_addr: address of the current node, if not provided will be resolved from hostname """ # TODO swap to collectives comms API if rank == 0: - addr = socket.getfqdn() + addr = local_addr or socket.getfqdn() port = _get_free_port() store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8")) # type: ignore[arg-type] store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8")) # type: ignore[arg-type] diff --git a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py index 26c3153d9785b6..427a53bc3276dd 100644 --- a/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py @@ -28,6 +28,9 @@ logger = logging.getLogger(__name__) +# default port for the TCP store +DEFAULT_PORT = 29400 + class C10dRendezvousBackend(RendezvousBackend): """Represents a C10d-backed rendezvous backend. @@ -132,7 +135,7 @@ def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]: def _create_tcp_store(params: RendezvousParameters) -> TCPStore: - host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400) + host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT) cfg_is_host = params.get_as_bool("is_host") # If the user has explicitly specified whether our process should host the @@ -144,8 +147,6 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore: else: is_host = _matches_machine_hostname(host) - use_libuv = params.get_as_bool("use_libuv", False) - # The timeout read_timeout = cast(int, params.get_as_int("read_timeout", 60)) if read_timeout <= 0: @@ -161,7 +162,6 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore: is_master=is_server, multi_tenant=True, timeout=timedelta(seconds=read_timeout), - use_libuv=use_libuv, ) if is_server: diff --git a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py index 31627cf0a0b27c..c8e294604501db 100644 --- a/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py @@ -1115,7 +1115,6 @@ def _create_tcp_store_server(self, bootstrap_store_info) -> dist.TCPStore: bootstrap_store_info.master_port, is_master=True, multi_tenant=True, - use_libuv=True, ) @property @@ -1179,7 +1178,9 @@ def next_rendezvous(self) -> RendezvousInfo: # opt-out option of TCP store sharing if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": - bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) return RendezvousInfo( store, rank, @@ -1199,7 +1200,9 @@ def next_rendezvous(self) -> RendezvousInfo: else: # If the store is not type of TCPStore start TCPStore server, which requries # bootstrapping info across ranks - self._bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + self._bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._this_node.addr + ) if rank == 0: self._shared_tcp_store_server = self._create_tcp_store_server( self._bootstrap_store_info diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index fe6170ede0159e..f0aa3c8d8887d8 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -149,8 +149,15 @@ class EtcdRendezvousHandler(RendezvousHandler): +--------------------------------------------+--------------------------+ """ - def __init__(self, rdzv_impl): + def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): + """ + Args: + rdzv_impl: the implementation of the rendezvous + local_addr: the local address of the current node + """ + self._rdzv_impl = rdzv_impl + self._local_addr = local_addr def __del__(self): # TODO: look into using weakref here instead. @@ -165,7 +172,9 @@ def next_rendezvous(self): logger.info("Creating EtcdStore as the c10d::Store implementation") store = self._rdzv_impl.setup_kv_store(rdzv_version) - bootstrap_store_info = RendezvousStoreInfo.build(rank, store) + bootstrap_store_info = RendezvousStoreInfo.build( + rank, store, local_addr=self._local_addr + ) return RendezvousInfo(store, rank, world_size, bootstrap_store_info) def is_closed(self): @@ -1062,4 +1071,7 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT ), ) - return EtcdRendezvousHandler(rdzv_impl=rdzv) + return EtcdRendezvousHandler( + rdzv_impl=rdzv, + local_addr=params.local_addr, + ) diff --git a/torch/distributed/elastic/rendezvous/registry.py b/torch/distributed/elastic/rendezvous/registry.py index 1a91d0a8ff7946..d038ab95eabae6 100644 --- a/torch/distributed/elastic/rendezvous/registry.py +++ b/torch/distributed/elastic/rendezvous/registry.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging +import sys + from .api import ( rendezvous_handler_registry as handler_registry, RendezvousHandler, @@ -12,6 +15,13 @@ from .dynamic_rendezvous import create_handler +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points + +log = logging.getLogger(__name__) + __all__ = ["get_rendezvous_handler"] @@ -50,6 +60,21 @@ def _register_default_handlers() -> None: handler_registry.register("static", _create_static_handler) +def _register_out_of_tree_handlers() -> None: + discovered_handler_generators = entry_points(group="torchrun.handlers") + + for handler_generator in discovered_handler_generators: + try: + get_handler = discovered_handler_generators[handler_generator.name].load() + handler_registry.register(handler_generator.name, get_handler()) + except Exception: + log.warning( + "Exception while registering out of tree plugin %s: ", + handler_generator.name, + exc_info=True, + ) + + def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler: """ Obtain a reference to a :py:class`RendezvousHandler`. diff --git a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py index 5d2679d9fb4a0c..e6395b70be2b43 100644 --- a/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py @@ -122,6 +122,7 @@ def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: timeout = int(params.config["timeout"]) else: timeout = _default_timeout_seconds + return StaticTCPRendezvous( master_addr, master_port, rank, world_size, run_id, timeout ) diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 74da756d58c99a..f762614a8e6c59 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -193,10 +193,11 @@ def __init__( def start(self) -> None: logger.info( - "Starting %s..." " max_interval=%s," " daemon=%s", + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", type(self).__name__, self._max_interval, self._daemon, + self._file_path, ) self._watchdog_thread = threading.Thread( target=self._watchdog_loop, daemon=self._daemon @@ -268,7 +269,7 @@ def _run_watchdog(self, fd: io.TextIOWrapper) -> None: log_debug_info_for_expired_timers( self._run_id, { - pid: self._get_scopes(expired_timers) + pid: [expired_timer.to_json() for expired_timer in expired_timers] for pid, expired_timers in all_expired_timers.items() }, ) diff --git a/torch/distributed/elastic/utils/distributed.py b/torch/distributed/elastic/utils/distributed.py index 1a7ea81451f7c1..2e216da72c2547 100644 --- a/torch/distributed/elastic/utils/distributed.py +++ b/torch/distributed/elastic/utils/distributed.py @@ -7,7 +7,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import datetime -import functools +import os import socket from contextlib import closing from typing import Optional @@ -37,6 +37,16 @@ def create_c10d_store( retries=3, use_libuv: Optional[bool] = None, ): + if use_libuv is not None: + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + if server_port == -1 and world_size > 1: raise ValueError( f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" @@ -68,20 +78,15 @@ def create_c10d_store( ) try: - store_builder = functools.partial( - dist.TCPStore, + store = dist.TCPStore( host_name=server_addr, port=port, world_size=world_size, is_master=is_server, timeout=datetime.timedelta(seconds=timeout), wait_for_workers=wait_for_workers, + use_libuv=use_libuv, ) - if use_libuv is None: - # TCPStore default backend may change, don't specify it unless we explicity told to do so. - store = store_builder() - else: - store = store_builder(use_libuv=use_libuv) # skips full rank check when we don't have to wait for all workers if wait_for_workers: _check_full_rank(store, world_size, timeout=timeout) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index d722d5b9825999..396f058c45a1ae 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -87,13 +87,15 @@ def __init__(self, device: torch.device, backend: Any = None): @classmethod def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": """ - Return an device handle corresponding to the device, and through this handle, + Return a device handle corresponding to the device, and through this handle, operations with the same semantics as CUDA can be performed on the device. Just return torch.cuda if the device is cuda to make attribute-access faster. Custom backend must first register a module with the same name with {device.type} on torch. """ if device.type == "cuda": return cast(_FSDPDeviceHandle, torch.cuda) + elif device.type == "mtia": + return cast(_FSDPDeviceHandle, torch.mtia) return cls(device) def __getattr__(self, __name: str) -> Any: @@ -142,7 +144,7 @@ def __init__(self) -> None: self._gradient_postdivide_factor: int = 0 self._comm_hook: Optional[Callable] = None self._comm_hook_state: Optional[Any] = None - self._unshard_event: Optional[torch.cuda.Event] = None + self._unshard_event: Optional[torch.Event] = None # Abstract device handle for fsdp compute device. For now, # the compute device must implement cuda semantics used by fsdp self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() @@ -533,8 +535,12 @@ def forward_post_hook(module, args, output): def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: - # FIXME record_stream doesn't work with non-cuda tensors - if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]: + # FIXME record_stream doesn't work with non-cuda/mtia tensors + if tensor.device.type not in [ + "cuda", + "mtia", + torch._C._get_privateuse1_backend_name(), + ]: return if torch.distributed._functional_collectives.is_torchdynamo_compiling(): diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 8bc975dc72fd5a..ede7d06ec9a1de 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -1887,7 +1887,7 @@ def _use_unsharded_views(self, as_params: bool) -> None: flat_param = self.flat_param self._check_unsharded(flat_param) views = self._get_unflat_views() - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor for i, (view, (param_name, module, _)) in enumerate( zip(views, flat_param._param_infos) @@ -2717,7 +2717,7 @@ def _warn_use_fake_reduce(log: logging.Logger, warning: str): def _same_storage(a, b): # Params are DTensors in backward # with SHARD_GRAD_OP + TP - from torch.distributed._tensor import DTensor + from torch.distributed.tensor import DTensor if isinstance(a, DTensor): a = a._local_tensor diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py index 7d5d1e5f98a342..d96fa99a236571 100644 --- a/torch/distributed/fsdp/_fsdp_extensions.py +++ b/torch/distributed/fsdp/_fsdp_extensions.py @@ -5,12 +5,12 @@ import torch.distributed as dist from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed._shard.sharded_tensor.shard import Shard -from torch.distributed._tensor import DeviceMesh, DTensor from torch.distributed.fsdp._shard_utils import ( _all_gather_dtensor, _create_chunk_dtensor, _create_chunk_sharded_tensor, ) +from torch.distributed.tensor import DeviceMesh, DTensor class FSDPExtensions(ABC): diff --git a/torch/distributed/fsdp/_limiter_utils.py b/torch/distributed/fsdp/_limiter_utils.py index efb5b3ba5ae1f2..5cc56b29f84d4a 100644 --- a/torch/distributed/fsdp/_limiter_utils.py +++ b/torch/distributed/fsdp/_limiter_utils.py @@ -12,20 +12,20 @@ class _FreeEventQueue: """ def __init__(self) -> None: - self._queue: Deque[torch.cuda.Event] = collections.deque() + self._queue: Deque[torch.Event] = collections.deque() self._max_num_inflight_all_gathers = 2 # empirically chosen - def enqueue(self, free_event: torch.cuda.Event) -> None: + def enqueue(self, free_event: torch.Event) -> None: """Enqueues a free event.""" self._queue.append(free_event) - def dequeue_if_needed(self) -> Optional[torch.cuda.Event]: + def dequeue_if_needed(self) -> Optional[torch.Event]: """Dequeues a single event if the limit is reached.""" if len(self._queue) >= self._max_num_inflight_all_gathers: return self._dequeue() return None - def _dequeue(self) -> Optional[torch.cuda.Event]: + def _dequeue(self) -> Optional[torch.Event]: """Dequeues a free event if possible.""" if self._queue: event = self._queue.popleft() diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 4cfe761769a3b9..35beee36ef5839 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -27,7 +27,6 @@ import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.nn as nn from torch.distributed._state_dict_utils import _gather_state_dict -from torch.distributed._tensor import DTensor, Replicate from torch.distributed.distributed_c10d import _get_pg_default_device from torch.distributed.fsdp._common_utils import ( _apply_to_modules, @@ -53,6 +52,7 @@ StateDictSettings, StateDictType, ) +from torch.distributed.tensor import DTensor, Replicate from torch.utils._pytree import tree_map_only diff --git a/torch/distributed/fsdp/_shard_utils.py b/torch/distributed/fsdp/_shard_utils.py index b6cd1e108fb767..9769d5296d4efb 100644 --- a/torch/distributed/fsdp/_shard_utils.py +++ b/torch/distributed/fsdp/_shard_utils.py @@ -15,7 +15,7 @@ TensorProperties, ) from torch.distributed._shard.sharding_spec import ShardMetadata -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard def _get_remote_device_str(rank, device_type, num_devices_per_node): diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index a64c6413053ecc..f96872bfa6e7cf 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -25,7 +25,6 @@ Shard, ShardedTensor, ) -from torch.distributed._tensor import DTensor from torch.distributed.device_mesh import _mesh_resources from torch.distributed.fsdp._common_utils import ( _FSDPState, @@ -49,6 +48,7 @@ ShardingStrategy, StateDictType, ) +from torch.distributed.tensor import DTensor from torch.distributed.utils import _replace_by_prefix from ._fsdp_extensions import ( @@ -730,13 +730,18 @@ def _post_state_dict_hook( for key, tensor in sorted(processed_state_dict.items()): if key.startswith(prefix) and isinstance(tensor, torch.Tensor): local_shape = tensor.shape + device = None if isinstance(tensor, ShardedTensor): local_shape = None shards = tensor.local_shards() if shards: local_shape = shards[0].tensor.shape + device = shards[0].tensor.device elif isinstance(tensor, DTensor): local_shape = tensor.to_local().shape + device = tensor.device + else: + device = tensor.device logger.info( "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s", key, @@ -744,7 +749,7 @@ def _post_state_dict_hook( tensor.shape, local_shape, tensor.dtype, - tensor.device, + device, ) return processed_state_dict diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 92fa78e13ae9a1..3ff9e80ca12e3c 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -25,7 +25,6 @@ import torch.distributed as dist import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.nn as nn -from torch.distributed._tensor import DeviceMesh from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_WRAPPED_MODULE, ActivationWrapper, @@ -84,6 +83,7 @@ StateDictSettings, StateDictType, ) +from torch.distributed.tensor import DeviceMesh from torch.distributed.utils import _p_assert from ._flat_param import FlatParameter, FlatParamHandle @@ -1130,28 +1130,40 @@ def clip_grad_norm_( # where sharded and non-sharded parameters must be handled separately max_norm = float(max_norm) norm_type = float(norm_type) - sharded_params = set() - nonsharded_params = set() # `NO_SHARD` or not FSDP-managed + sharded_params_set = set() + nonsharded_params_set = set() # `NO_SHARD` or not FSDP-managed + # Make sure to compute the local norm using lists for deterministic + # iteration order and hence deterministic total norm computation + sharded_params = [] + nonsharded_params = [] grads: List[torch.Tensor] = [] for handle in self._all_handles: - target_set = ( - sharded_params if handle.uses_sharded_strategy else nonsharded_params - ) + if handle.uses_sharded_strategy: + target_set = sharded_params_set + target_list = sharded_params + else: + target_set = nonsharded_params_set + target_list = nonsharded_params if handle._use_orig_params: for param in handle.flat_param._params: - target_set.add(param) - if param.grad is not None: - grads.append(param.grad) + if param not in target_set: + target_set.add(param) + target_list.append(param) + if param.grad is not None: + grads.append(param.grad) else: - target_set.add(handle.flat_param) - if handle.flat_param.grad is not None: - grads.append(handle.flat_param.grad) + if handle.flat_param not in target_set: + target_set.add(handle.flat_param) + target_list.append(handle.flat_param) + if handle.flat_param.grad is not None: + grads.append(handle.flat_param.grad) for param in self.parameters(): not_fsdp_managed = ( - param not in sharded_params and param not in nonsharded_params + param not in sharded_params_set and param not in nonsharded_params_set ) if not_fsdp_managed: - nonsharded_params.add(param) + nonsharded_params_set.add(param) + nonsharded_params.append(param) if param.grad is not None: grads.append(param.grad) # Compute local norms (forced to be in FP32) @@ -2044,7 +2056,7 @@ class UnshardHandle: def __init__( self, flat_param_handle: Optional[FlatParamHandle], - unshard_event: torch.cuda.Event, + unshard_event: torch.Event, ): self._flat_param_handle = flat_param_handle self._unshard_event = unshard_event diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index 7c1b2f83528683..988ecb3533f5b9 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging from collections import abc, defaultdict -from typing import Any, Dict, Iterable, List, Optional, overload, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, overload, Tuple, Union import torch import torch.distributed as dist @@ -21,6 +21,7 @@ def _is_supported_device(tensor: torch.Tensor) -> bool: "xla", "cpu", "hpu", + "mtia", torch._C._get_privateuse1_backend_name(), ) @@ -167,36 +168,6 @@ def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): return apply_scale(outputs) - def _foreach_non_finite_check_and_unscale_cpu_( - self, - grads: Sequence[torch.Tensor], - found_inf: torch.Tensor, - inv_scale: torch.Tensor, - ) -> None: - if len(grads) == 0: - return - assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor." - assert found_inf.numel() == 1, "found_inf must be a 1-element tensor." - - for grad in grads: - if grad.device.type != "cpu": - logger.error( - "tensor device is %s but was expected to be ``cpu``", - grad.device, - ) - raise ValueError( - "Gradients were found on a non-CPU device when" - " expected to be on CPU." - ) - if ( - torch.isinf(grad).any().item() is True - or torch.isnan(grad).any().item() is True - ): - found_inf.data = torch.tensor([1.0]) - break - else: - grad.data *= inv_scale.item() - def _unscale_grads_( self, optimizer: torch.optim.Optimizer, @@ -240,18 +211,11 @@ def _unscale_grads_( for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - if grads[0].device.type == "cpu": - self._foreach_non_finite_check_and_unscale_cpu_( - grads, - per_device_found_inf.get(device), - per_device_inv_scale.get(device), - ) - else: - torch._amp_foreach_non_finite_check_and_unscale_( - grads, - per_device_found_inf.get(device), - per_device_inv_scale.get(device), - ) + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some # ranks may have no (non-zero sized) parameter shards, necessitating the # initialization of `per_device_found_inf._per_device_tensors` here diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py index db65856e32adad..3c0027d1124073 100644 --- a/torch/distributed/optim/post_localSGD_optimizer.py +++ b/torch/distributed/optim/post_localSGD_optimizer.py @@ -96,7 +96,7 @@ def load_state_dict(self, state_dict): ) self.averager.step = 0 - def step(self): + def step(self): # type: ignore[override] r""" Performs a single optimization step (parameter update). """ diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 7b6e39cfbe44bb..3010bccd377c94 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -13,7 +13,6 @@ import torch.fx as fx from torch.distributed import ProcessGroup from torch.export import ExportedProgram -from torch.export._trace import _export from torch.export.unflatten import ( _assign_attr, _AttrKind, @@ -365,7 +364,7 @@ def __init__(self, module, garbage_collect_values=True): super().__init__(module, garbage_collect_values) self.value_remap = {} - def run(self, *args, initial_env=None): + def run(self, *args, initial_env=None): # type: ignore[override] self.value_remap = {} return super().run(*args, initial_env=initial_env) @@ -933,8 +932,7 @@ def move_param_to_callee( if node.op == "get_attr": # get_attr might get access deeper level attribute fqn = scope + "." + node.target if scope else node.target - if fqn in unused_attributes: # used, remove it - unused_attributes.remove(fqn) + unused_attributes.discard(fqn) for _name, _submod in _mod.named_children(): stack.append((scope + "." + _name if scope else _name, _submod)) # delete unused attributes @@ -1005,14 +1003,11 @@ def _trace_with_export( ) -> ExportedProgram: logger.info("Tracing model ...") try: - with torch.no_grad(): - ep = _export( - mod, - example_args, - example_kwargs, - strict=True, - pre_dispatch=False, - ) + ep = torch.export.export( + mod, + example_args, + example_kwargs, + ) except Exception as e: raise RuntimeError( "It seems that we cannot capture your model as a full graph. " diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 5aaa6180721d97..a4c5516e83da01 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -159,7 +159,7 @@ def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, A def stage_backward_input( stage_outputs: List[torch.Tensor], output_grads: Optional[List[torch.Tensor]], - stage_inputs: List[torch.Tensor], + input_values: List[torch.Tensor], weights: Iterator[Parameter], ): """ @@ -169,7 +169,7 @@ def stage_backward_input( filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs)) ) stage_input_grad_fns: List[Node] = list( - filter(None, map(_get_grad_fn_or_grad_acc, stage_inputs)) + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) ) weight_grad_fns: List[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, weights)) @@ -197,7 +197,7 @@ def hook(grad_inputs): intermediate.register_prehook(get_hook(param_group, i)) # Stage 0 inputs do not require grads? Should we skip in that case? - if all(tensor.requires_grad for tensor in stage_inputs): + if all(tensor.requires_grad for tensor in input_values): if output_grads is None: # In case this is the loss and there are no output_grads, then we just use 1s output_grads = [ @@ -206,13 +206,13 @@ def hook(grad_inputs): dinputs = torch.autograd.grad( stage_outputs, - inputs=stage_inputs, + inputs=input_values, grad_outputs=output_grads, retain_graph=True, ) # update the gradients for inputs - for i, inp in enumerate(stage_inputs): + for i, inp in enumerate(input_values): if inp.grad is None: inp.grad = dinputs[i] else: @@ -225,7 +225,14 @@ def hook(grad_inputs): def stage_backward_weight( weights: Iterator[Parameter], param_groups: List[Dict[str, Any]] ): - all_dweights = dict() + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + for param_group in param_groups: # TODO: Handle case where intermediate can have multiple outputs intermediate_edges = tuple( @@ -242,19 +249,14 @@ def stage_backward_weight( weights_edges, grad_outputs=sum(param_group["grads"], tuple()), ) - for w, dw in zip(param_group["params"], dweights): - all_dweights[w] = dw + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw # return grads in the original order weights were provided in - out = [] - for w in weights: - grad_acc = _get_grad_fn_or_grad_acc(w) - dweight = all_dweights[grad_acc] - out.append(dweight) - if w.grad is None: - w.grad = dweight - else: - w.grad += dweight - return out + return weight_grads def stage_backward( diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 4761ba8dcb7787..3c4abfb7b0b352 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -10,7 +10,7 @@ import torch.fx as fx import torch.nn as nn from torch._subclasses.fake_tensor import FakeTensor -from torch.distributed._composable.fsdp.fully_shard import FSDPModule +from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard from torch.fx.node import map_aggregate from torch.nn.parallel import DistributedDataParallel @@ -468,14 +468,40 @@ def forward_maybe_with_nosync(self, *args, **kwargs): out_val = self.submod(*args, **kwargs) return out_val - def backward_maybe_with_nosync(self, bwd_kwargs: Dict): + def backward_maybe_with_nosync(self, backward_type, bwd_kwargs: Dict): """ Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but there are additional state-variables and performance considerations depending on the data parallelism used. This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. """ - last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + full_backward = bwd_kwargs["full_backward"] + if full_backward: + last_backward = self._seen_bwd_chunks == self.chunks - 1 # type: ignore[operator] + else: + # For backwards are split into weight and input, we will see twice as many bwd_chunks + last_backward = self._seen_bwd_chunks == 2 * self.chunks - 1 # type: ignore[operator] + + def perform_backward(backward_type): + if backward_type == "full": + return lambda: stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") # If submod is wrapped by DDP if isinstance(self.submod, DistributedDataParallel): @@ -489,21 +515,41 @@ def backward_maybe_with_nosync(self, bwd_kwargs: Dict): ) ) ) - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() else: with self.submod.no_sync(): # type: ignore[operator] - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() # If submod is a FSDP module elif isinstance(self.submod, FSDPModule): - self.submod.set_is_last_backward(last_backward) - self.submod.set_requires_gradient_sync(last_backward) - grads_input = stage_backward(**bwd_kwargs) + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + if last_backward: + # Manually call post backward for FSDP + def run_post_backward(fsdp_module: FSDPModule) -> None: + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + fsdp_state = fully_shard.state(fsdp_module) + for state in fsdp_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + run_post_backward(self.submod) else: # Non-DP submodule, regular backward - grads_input = stage_backward(**bwd_kwargs) + result = perform_backward(backward_type)() self._seen_bwd_chunks += 1 - return grads_input + + if isinstance(result, tuple) and len(result) == 2: + # for stage_backward_input() + grads, param_groups = result + else: + grads, param_groups = result, None + + return grads, param_groups def forward_one_chunk( self, @@ -611,41 +657,43 @@ def backward_one_chunk( "input_values": input_values, } + # Save full_backward + bwd_kwargs["full_backward"] = full_backward + # Custom backward function if self.dw_builder: # TODO: We may want to change our semantics so we are allowed to ignore # the 'dw_builder' and call full_backward directly when it is a full_backward op. - self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) + self.grads_input, _ = self.backward_maybe_with_nosync("full", bwd_kwargs) if full_backward: self.dw_builder()() else: self.dw_runner[bwd_chunk_id] = self.dw_builder() else: if full_backward: - self.grads_input = self.backward_maybe_with_nosync(bwd_kwargs) + self.grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs + ) else: # perform the partial backwards for the inputs with a custom backward function # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors if isinstance(bwd_kwargs["stage_output"], torch.Tensor): bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) - dinputs, param_groups = stage_backward_input( - bwd_kwargs["stage_output"], - bwd_kwargs["output_grads"], - bwd_kwargs["input_values"], - self.submod.parameters(), + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs ) + # TODO: we dont need to save this, add to dw_runner? self.backward_state[bwd_chunk_id] = ( - dinputs, input_values, param_groups, bwd_kwargs["stage_output"], bwd_kwargs["output_grads"], ) - self.grads_input = dinputs - # save a dw_runner with the `stage_backward_weight` function - self.dw_runner[bwd_chunk_id] = stage_backward_weight + self.grads_input = grads_input + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) def backward_weight_one_chunk(self, bwd_chunk_id: int): @@ -658,23 +706,32 @@ def backward_weight_one_chunk(self, bwd_chunk_id: int): self.dw_runner.pop(bwd_chunk_id)() else: ( - dinputs, input_values, param_groups, stage_output, output_grads, ) = self.backward_state.pop(bwd_chunk_id) + if self.stage_index != 0: - dweights = self.dw_runner.pop(bwd_chunk_id)( - self.submod.parameters(), param_groups - ) + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + "full_backward": False, + } + weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs) else: # TODO: figure out a better way to do this: # if inputs does not require gradient, # then the parameter group will not be fully captured during stage_backward_input # in this case, we need call grad directly on the parameters # To solve: make input fn do the intersect compute and then finish it off during W - torch.autograd.backward(stage_output, grad_tensors=output_grads) + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + "full_backward": False, + } + self.backward_maybe_with_nosync("full", bwd_kwargs) def _validate_fwd_input(self, args, kwargs): """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" diff --git a/torch/distributed/_tensor/README.md b/torch/distributed/tensor/README.md similarity index 100% rename from torch/distributed/_tensor/README.md rename to torch/distributed/tensor/README.md diff --git a/torch/distributed/tensor/__init__.py b/torch/distributed/tensor/__init__.py index e69de29bb2d1d6..f2746f60ba8720 100644 --- a/torch/distributed/tensor/__init__.py +++ b/torch/distributed/tensor/__init__.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py new file mode 100644 index 00000000000000..8684d7b0cafa0f --- /dev/null +++ b/torch/distributed/tensor/_api.py @@ -0,0 +1,1234 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import inspect +import warnings +from typing import Any, Callable, cast, Optional, Sequence, Tuple + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._random import ( + is_rng_supported_mesh, + OffsetBasedRNGTracker, +) +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape_and_global_offset, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + dist_tensor = DTensor( + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + if local_tensor.requires_grad and not requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + + r._spec = spec + r._local_tensor = local_tensor + return r + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): # type: ignore[override] + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert ( + flatten_spec is not None + ), "Expecting spec to be not None from `__tensor_flatten__` return value!" + local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec): + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + run_check: bool = False, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + async_op: bool = False, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from is current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + When redistributing from current to the new placements on one device mesh dimension, we + will perform the following operations including communication collective or local operation: + + 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` + 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` + 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) + 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` + 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` + + + ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors + that are created either on 1-D or N-D DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. If not specified, it would use the current DTensor's DeviceMesh. + default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + default: replicate on all mesh dimensions + + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + + Returns: + A :class:`DTensor` object + + .. note:: ``redistribute`` is differentiable, which means user do not need to worry about + the backward formula of the redistribute operation. + + .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, + Please file an issue if you need to redistribute DTensor to different DeviceMesh. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial(): + raise RuntimeError( + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply(self, device_mesh, placements, async_op) + + def full_tensor( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntatic sugar of the following code: + + ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: ``full_tensor`` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: ``device_mesh`` is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> Tuple[Placement, ...]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: ``placements`` is a read-only property, it can not be set. + """ + return self._spec.placements + + def __create_write_items__(self, fqn: str, object: Any): + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the + same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use + the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to perserve + the single-device semantic. If you want to construct a DTensor in the middle of the Autograd + computation, please use :meth:`DTensor.from_local` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use ``torch.chunk`` + semantic to shard the tensor and scatter the shards. The uneven sharding + behavior is experimental and subject to change. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as ``device_mesh.ndim``. If not specified, we will + by default replicate the tensor across the ``device_mesh`` from the + first rank of each dimension of the `device_mesh`. + + Returns: + A :class:`DTensor` or ``XLAShardedTensor`` object. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` + return `XLAShardedTensor` instead. see `this issue `__ + for more details. The XLA integration is experimental and subject to change. + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") + + # get default device mesh if there's nothing specified + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # call PyTorch/XLA SPMD for `xla` backend type device mesh. + # This returns XLAShardedTensor + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_tensor, + ) + + return xla_distribute_tensor( + tensor, device_mesh, placements + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. + # TODO: the value assignment to global variable is not the ideal solution + # we can replace it in future. + if not random._rng_tracker and is_rng_supported_mesh(device_mesh): + random._rng_tracker = OffsetBasedRNGTracker(device_type) + + if not tensor.is_leaf: + raise RuntimeError( + "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" + ) + + # convert tensor to the corresponding device type if it's not in that device type + if device_type != tensor.device.type and not tensor.is_meta: + tensor = tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we need to check: + # 1. if the we can further shard this DTensor if the two device mesh belong to + # the same parenet mesh and further sharding is possible. + # 2. check if device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != tuple(placements): + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor.detach() + + # TODO(xilun): address sharding order + # distribute the tensor according to the placements. + placements = list(placements) + for idx, placement in enumerate(placements): + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + # normalize shard placement dim + placement = Shard(placement.dim + tensor.ndim) + placements[idx] = placement + local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) + elif placement.is_replicate(): + placement = cast(Replicate, placement) + local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + placements = tuple(placements) + + assert local_tensor is not None, "distributing a tensor should not be None" + # detach the local tensor passed to DTensor since after the construction + # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + return DTensor( + local_tensor.requires_grad_(tensor.requires_grad), + spec, + requires_grad=tensor.requires_grad, + ) + + +def distribute_module( + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, + input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, + output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, +) -> nn.Module: + """ + This function expose three functions to control the parameters/inputs/outputs of the module: + + 1. To perform sharding on the module before runtime execution by specifying the + ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` + parameters according to the `partition_fn` specified). + 2. To control the inputs or outputs of the module during runtime execution by + specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to + :class:`DTensor`, convert the output back to ``torch.Tensor``) + + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the ``device_mesh``). If ``partition_fn`` is not specified, + by default we replicate all module parameters of ``module`` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. ``input_fn`` will be installed as a module + ``forward_pre_hook`` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be + installed as a module ``forward_hook`` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all ``DTensor`` s. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` + return nn.Module with PyTorch/XLA SPMD annotated parameters. See + `this issue `__ + for more details. The XLA integration is experimental and subject to change. + + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for name, submod in module.named_modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" + ) + # register output_fn as module forward hook + if output_fn is not None: + num_args = len(inspect.signature(output_fn).parameters) + if num_args == 2: + # output_fn only takes in outputs and device mesh + warnings.warn( + "Deprecating output_fn that takes two arguments (inputs, device_mesh), " + "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) + ) + else: + raise ValueError( + f"output_fn should take in 3 arguments, but got {num_args} arguments!" + ) + + return module + + +# Below are tensor factory function APIs, which are used to create a DTensor directly. We need +# to make separate factory function APIs because tensor subclass could not override the tensor +# factory methods, and we need user to call the factory functions with user intended device_mesh +# and placements to create a proper DTensor. + + +def _dtensor_init_helper( # type: ignore[no-untyped-def] + init_op, + size: torch.Size, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + **kwargs, +) -> DTensor: + # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta + + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh againts placements + assert device_mesh.ndim == len( + placements + ), "mesh dimension does not match the length of placements" + + assert kwargs["layout"] == torch.strided, "layout value not supported!" + torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape, _ = compute_local_shape_and_global_offset( + size, device_mesh, placements + ) + + # initialize the local tensor + if init_op == torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op == torch.rand or init_op == torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker() + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) + + +def ones( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py similarity index 97% rename from torch/distributed/_tensor/_collective_utils.py rename to torch/distributed/tensor/_collective_utils.py index a59aa44492fa77..34795858356ab8 100644 --- a/torch/distributed/_tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -7,7 +7,7 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.placement_types as placement_types +import torch.distributed.tensor._dtensor_spec as dtensor_spec from torch._C._distributed_c10d import _resolve_process_group from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.distributed_c10d import ( @@ -202,7 +202,7 @@ def fill_empty_tensor_to_shards( def check_tensor_meta( local_tensor, check_shape_stride=False -) -> Optional["placement_types.TensorMeta"]: +) -> Optional["dtensor_spec.TensorMeta"]: local_metadata = { "dtype": local_tensor.dtype, "requires_grad": local_tensor.requires_grad, @@ -224,7 +224,7 @@ def check_tensor_meta( return None -def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int: +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: assert spec.tensor_meta is not None, "spec should have tensor meta defined!" return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) @@ -311,8 +311,8 @@ def reduce_scatter_cost( def redistribute_cost( - current_spec: "placement_types.DTensorSpec", - target_spec: "placement_types.DTensorSpec", + current_spec: "dtensor_spec.DTensorSpec", + target_spec: "dtensor_spec.DTensorSpec", ) -> float: """ This function returns the cost of redistribute from current to target DTensorSpec. diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py similarity index 78% rename from torch/distributed/_tensor/_dispatch.py rename to torch/distributed/tensor/_dispatch.py index 129e556e747298..4579a16826d0fb 100644 --- a/torch/distributed/_tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -8,24 +8,25 @@ import torch import torch.distributed as dist -import torch.distributed._tensor.api as dtensor -import torch.distributed._tensor.random as random -from torch.distributed._tensor._op_schema import ( +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpInfo, OpSchema, OutputSpecType, ) -from torch.distributed._tensor._redistribute import redistribute_local_tensor -from torch.distributed._tensor._sharding_prop import ShardingPropagator -from torch.distributed._tensor._tp_conv import ( +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( convolution_backward_handler, convolution_handler, ) -from torch.distributed._tensor._utils import try_find_mesh_from_args -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, TensorMeta -from torch.distributed._tensor.random import is_rng_supported_mesh +from torch.distributed.tensor._utils import try_find_mesh_from_args +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate if TYPE_CHECKING: @@ -66,6 +67,46 @@ def is_same_size_handler( return lhs.shape == rhs.shape +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(List[object], op_info.local_args), op_info.args_tree_spec + ) + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, spec=spec, requires_grad=False + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + class OpDispatcher: """ Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding @@ -99,6 +140,7 @@ def __init__(self) -> None: aten.is_same_size.default: is_same_size_handler, aten.convolution.default: convolution_handler, aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, } # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) @@ -309,16 +351,20 @@ def unwrap_to_op_info( for arg in args_list: if isinstance(arg, dtensor.DTensor): - args_schema.append(arg._spec) local_args.append(arg._local_tensor) - if mesh is not None: - if mesh != arg.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - f"Got meshes: {mesh} {arg.device_mesh}" - ) + if mesh is not None and mesh != arg.device_mesh: + # TODO: try replicate dtensor spec in missing dimension would work + # for most cases for foreach case except when the first DTensor in + # the list is one that also need to be replicated. We need to revisit + # how we want to handle this corner case. For now, this case would hit + # the cross mesh error even if implicit replication is turned on. + spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, arg, mesh + ) + args_schema.append(spec) else: mesh = arg.device_mesh + args_schema.append(arg._spec) elif isinstance(arg, torch.Tensor): mesh = mesh or try_find_mesh_from_args(op_call, args_list) args_schema.append( @@ -331,15 +377,15 @@ def unwrap_to_op_info( for k, v in kwargs.items(): if isinstance(v, dtensor.DTensor): - kwargs_schema[k] = v._spec local_kwargs[k] = v._local_tensor - if mesh is not None: - if mesh != v.device_mesh: - raise NotImplementedError( - f"{op_call}: DTensor does not support cross-mesh operation yet!" - ) + if mesh is not None and mesh != v.device_mesh: + spec = self._try_replicate_dtensor_spec_in_missing_dim( + op_call, v, mesh + ) + kwargs_schema[k] = spec else: mesh = v.device_mesh + kwargs_schema[k] = v._spec elif isinstance(v, torch.Tensor): mesh = mesh or try_find_mesh_from_args(op_call, args_list) kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( @@ -426,3 +472,39 @@ def _try_replicate_spec_for_scalar_tensor( " torch.Tensor to DTensor before calling distributed operators!" ) return replication_spec + + def _try_replicate_dtensor_spec_in_missing_dim( + self, + op_call: torch._ops.OpOverload, + dtensor_arg: "dtensor.DTensor", + mesh: "DeviceMesh", + ) -> DTensorSpec: + # util function to produce a new spec for a DTensor arg/kwarg + # that puts Replicate() placement in the missing dimension for foreach ops + from torch.distributed.device_mesh import _mesh_resources + + cur_mesh = dtensor_arg.device_mesh + root_mesh = _mesh_resources.get_root_mesh(cur_mesh) + if ( + self._allow_implicit_replication + and "foreach" in op_call.__name__ + and root_mesh == mesh + ): + placements = [Replicate() for _ in range(root_mesh.ndim)] + cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh) + placements[cur_mesh_root_idx] = dtensor_arg.placements[0] # type: ignore[call-overload] + replicate_spec = DTensorSpec( + root_mesh, + tuple(placements), + tensor_meta=TensorMeta( + shape=dtensor_arg.shape, + stride=dtensor_arg.stride(), + dtype=dtensor_arg.dtype, + ), + ) + else: + raise NotImplementedError( + f"{op_call}: DTensor does not support cross-mesh operation yet! " + f"Got meshes: {mesh} {cur_mesh}" + ) + return replicate_spec diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 00000000000000..e80729c7b62869 --- /dev/null +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,276 @@ +from dataclasses import dataclass +from typing import Any, cast, List, NamedTuple, Optional, Tuple + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: Tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: Tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: Optional[TensorMeta] = None + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + self._hash: Optional[int] = None + + def __setattr__(self, attr: str, value: Any) -> None: + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh` or `placements` to change) + if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): + self._hash = None + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def __eq__(self, __o: object) -> bool: + if not ( + isinstance(__o, DTensorSpec) + and self.mesh == __o.mesh + and self.placements == __o.placements + ): + return False + if self.tensor_meta is None or __o.tensor_meta is None: + return self.tensor_meta == __o.tensor_meta + + return ( + self.tensor_meta.shape == __o.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == __o.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == __o.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + if len(self.placements) == 1: + placement_str = str(self.placements[0]) + else: + placement_str = str(self.placements) + + if self.tensor_meta is not None: + tensor_shape = str(tuple(self.tensor_meta.shape)) + else: + tensor_shape = "unknown shape" + + return f"Spec({placement_str} on {tensor_shape})" + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> Tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def num_shards_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. Unlike `dim_map`, `num_shards_map` + denotes how many shards each tensor dim has. Like `dim_map`: + len(num_shards_map) == dist_tensor.ndim + num_shards_map[i] = 1: means tensor dim i is not sharded + num_shards_map[i] = j: means tensor dim i has j shards in total + + For example, we have a dist tensor of shape [18, 20, 30], + a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements + ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor + would be: [4, 2, 1]. + """ + r = [1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + r[shard_dim] *= self.mesh.size(i) + + return r + + @property + def sums(self) -> List[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: List[int], + sums: List[int], + tensor_meta: Optional[TensorMeta] = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self) -> bool: + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_sharded(self) -> bool: + """ + return True if the current DTensorSpec is sharded on any mesh dims (devices) + """ + return any(placement.is_shard() for placement in self.placements) + + def shallow_copy_with_tensor_meta( + self, tensor_meta: Optional[TensorMeta] + ) -> "DTensorSpec": + """ + Shallow copy the DTensorSpec with a new tensor_meta. + """ + assert tensor_meta is not None, "shallow copy with no tensor_meta!" + return DTensorSpec( + self.mesh, + self.placements, + tensor_meta=tensor_meta, + ) diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py similarity index 99% rename from torch/distributed/_tensor/_op_schema.py rename to torch/distributed/tensor/_op_schema.py index b2461086149929..190886da21fd22 100644 --- a/torch/distributed/_tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -5,8 +5,9 @@ import torch from torch._ops import OpOverload -from torch.distributed._tensor.placement_types import DTensorSpec, Placement from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement try: diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/tensor/_ops/__init__.py similarity index 100% rename from torch/distributed/_tensor/ops/__init__.py rename to torch/distributed/tensor/_ops/__init__.py diff --git a/torch/distributed/_tensor/ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py similarity index 97% rename from torch/distributed/_tensor/ops/_common_rules.py rename to torch/distributed/tensor/_ops/_common_rules.py index f70b27076de7c4..059dd04bd2f4d4 100644 --- a/torch/distributed/_tensor/ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -2,15 +2,15 @@ from typing import cast, Dict, List, Optional, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, OutputSharding, ) -from torch.distributed._tensor._utils import compute_local_shape -from torch.distributed._tensor.ops.utils import prod -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops.utils import prod +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: @@ -171,7 +171,7 @@ def merge_sharding(dim: str, a: int, b: int) -> int: ): assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape - local_shape = compute_local_shape( + local_shape, _ = compute_local_shape_and_global_offset( global_shape, input_spec.mesh, input_spec.placements ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) diff --git a/torch/distributed/_tensor/ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py similarity index 93% rename from torch/distributed/_tensor/ops/_conv_ops.py rename to torch/distributed/tensor/_ops/_conv_ops.py index 3a3b743dc4710e..db2a8136e14da0 100644 --- a/torch/distributed/_tensor/ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -4,9 +4,9 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import OpSchema, OutputSharding -from torch.distributed._tensor.ops.utils import register_prop_rule -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._ops.utils import register_prop_rule aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py similarity index 97% rename from torch/distributed/_tensor/ops/_einsum_strategy.py rename to torch/distributed/tensor/_ops/_einsum_strategy.py index 97dd43b1524dc8..fc3227600b35d2 100644 --- a/torch/distributed/_tensor/ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -2,15 +2,15 @@ from dataclasses import dataclass from typing import List, Set, Tuple -from torch.distributed._tensor._op_schema import OpStrategy, PlacementStrategy -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpStrategy, PlacementStrategy +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh @dataclass diff --git a/torch/distributed/_tensor/ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_embedding_ops.py rename to torch/distributed/tensor/_ops/_embedding_ops.py index 15b2af2a01c17b..ae333b800ffcb8 100644 --- a/torch/distributed/_tensor/ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -7,23 +7,23 @@ import torch import torch.distributed._functional_collectives as funcol -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, StrategyType, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_experimental_ops.py b/torch/distributed/tensor/_ops/_experimental_ops.py similarity index 69% rename from torch/distributed/_tensor/ops/_experimental_ops.py rename to torch/distributed/tensor/_ops/_experimental_ops.py index 1edfb8d294ffea..d4ab5cc3aeeb02 100644 --- a/torch/distributed/_tensor/ops/_experimental_ops.py +++ b/torch/distributed/tensor/_ops/_experimental_ops.py @@ -3,15 +3,16 @@ # implement matrix related ops for distributed tensor import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.ops.utils import register_op_strategy -from torch.distributed._tensor.placement_types import DTensorSpec, Replicate +from torch.distributed.tensor._ops.utils import register_op_strategy +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Replicate aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py similarity index 99% rename from torch/distributed/_tensor/ops/_math_ops.py rename to torch/distributed/tensor/_ops/_math_ops.py index bfa2622c96cade..4905c338918595 100644 --- a/torch/distributed/_tensor/ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -7,7 +7,9 @@ from typing import cast, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, @@ -15,8 +17,7 @@ RuntimeSchemaInfo, TupleStrategy, ) -from torch.distributed._tensor._utils import normalize_to_torch_size -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( as_list, expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -25,14 +26,13 @@ normalize_dims, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor._utils import normalize_to_torch_size +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_matrix_ops.py rename to torch/distributed/tensor/_ops/_matrix_ops.py index 9d63dd53389987..fd9a7a430a70eb 100644 --- a/torch/distributed/_tensor/ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -5,15 +5,17 @@ from typing import List import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed._tensor.ops._einsum_strategy import gen_einsum_strategies -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, infer_broadcast_dims_map, @@ -21,13 +23,7 @@ map_placements_after_broadcast, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, - Placement, - Replicate, - Shard, -) -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_pointwise_ops.py rename to torch/distributed/tensor/_ops/_pointwise_ops.py index f3f32805dca065..bb40865ed9c06a 100644 --- a/torch/distributed/_tensor/ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -2,7 +2,9 @@ from typing import List, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( _is_inplace_op, _is_out_variant_op, OpSchema, @@ -12,21 +14,19 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, register_op_strategy, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -239,7 +239,12 @@ aten.igammac.default, aten.igammac.out, aten.igammac_.default, + aten.isinf.default, aten.isnan.default, + aten.isneginf.default, + aten.isneginf.out, + aten.isposinf.default, + aten.isposinf.out, aten.ldexp.default, aten.ldexp.out, aten.ldexp_.default, @@ -580,6 +585,7 @@ def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> Strategy aten._foreach_cos_.default, aten._foreach_log.default, aten._foreach_log_.default, + aten._amp_foreach_non_finite_check_and_unscale_.default, ] for_each_linearity_ops = [ diff --git a/torch/distributed/_tensor/ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py similarity index 90% rename from torch/distributed/_tensor/ops/_random_ops.py rename to torch/distributed/tensor/_ops/_random_ops.py index f54cfea7b00509..726b25e1eed023 100644 --- a/torch/distributed/_tensor/ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-decorators # Copyright (c) Meta Platforms, Inc. and affiliates import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, StrategyType, ) -from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_tensor_ops.py rename to torch/distributed/tensor/_ops/_tensor_ops.py index 335773f4f8a286..e9bcb3b0d12240 100644 --- a/torch/distributed/_tensor/ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -4,7 +4,9 @@ from typing import cast, List, Optional, Sequence, Tuple import torch -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -15,9 +17,9 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops._common_rules import pointwise_rule -from torch.distributed._tensor.ops._embedding_ops import _MaskPartial -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops._common_rules import pointwise_rule +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, is_tensor_dim_sharded, is_tensor_evenly_shardable, @@ -26,14 +28,12 @@ register_op_strategy, register_prop_rule, ) -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten diff --git a/torch/distributed/_tensor/ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py similarity index 98% rename from torch/distributed/_tensor/ops/_view_ops.py rename to torch/distributed/tensor/_ops/_view_ops.py index 0d7299544d4e4e..451b92c80b24b8 100644 --- a/torch/distributed/_tensor/ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -17,23 +17,23 @@ import torch from torch import Tensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementStrategy, RuntimeSchemaInfo, StrategyType, ) -from torch.distributed._tensor.api import Shard -from torch.distributed._tensor.ops.utils import ( +from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, register_op_strategy, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard aten = torch.ops.aten @@ -377,7 +377,7 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: if len(to_group_shape) > 0: flattened = Flatten.new( - tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] > 1) + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1) ) result_pp += [ Split.new(flattened, tuple(to_group_shape), i) diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/tensor/_ops/utils.py similarity index 96% rename from torch/distributed/_tensor/ops/utils.py rename to torch/distributed/tensor/_ops/utils.py index 31a43b5ad8a80b..334bcc4a37fea3 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -6,18 +6,18 @@ from typing import cast, Iterable, List, Optional, Sequence, Tuple, Union import torch -from torch.distributed._tensor._collective_utils import redistribute_cost -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( OpSchema, OpStrategy, PlacementList, PlacementStrategy, RuntimeSchemaInfo, ) -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, diff --git a/torch/distributed/_tensor/random.py b/torch/distributed/tensor/_random.py similarity index 98% rename from torch/distributed/_tensor/random.py rename to torch/distributed/tensor/_random.py index ac3771be0ba2f0..db4b2832548c07 100644 --- a/torch/distributed/_tensor/random.py +++ b/torch/distributed/tensor/_random.py @@ -7,8 +7,9 @@ import torch import torch.distributed as dist from torch import Tensor -from torch.distributed._tensor.placement_types import DTensorSpec, Shard from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Shard __all__ = [ @@ -290,7 +291,7 @@ def _set_pre_op_offset(self, spec: DTensorSpec) -> None: return_offset=False, )[0] - from torch.distributed._tensor.ops.utils import prod + from torch.distributed.tensor._ops.utils import prod local_size = prod(local_size_on_rank_0) @@ -317,7 +318,7 @@ def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: """ dtensor_shape = spec.shape - from torch.distributed._tensor.ops.utils import prod + from torch.distributed.tensor._ops.utils import prod numel = prod(dtensor_shape) # pytorch: offset must be multiple of 4 diff --git a/torch/distributed/_tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py similarity index 94% rename from torch/distributed/_tensor/_redistribute.py rename to torch/distributed/tensor/_redistribute.py index 127bf3e9857d74..88414081a1785b 100644 --- a/torch/distributed/_tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -6,15 +6,14 @@ import torch import torch.distributed._functional_collectives as funcol -import torch.distributed._tensor.api as dtensor -from torch.distributed._tensor.device_mesh import DeviceMesh -from torch.distributed._tensor.placement_types import ( - DTensorSpec, +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( Partial, Placement, Replicate, Shard, - TensorMeta, ) @@ -28,8 +27,7 @@ class _TransformInfo(NamedTuple): logical_shape: List[int] -@lru_cache(maxsize=None) -def _gen_transform_infos( +def _gen_transform_infos_non_cached( src_spec: DTensorSpec, dst_spec: DTensorSpec, ) -> List[_TransformInfo]: @@ -147,6 +145,14 @@ def _gen_transform_infos( return transform_infos +@lru_cache(maxsize=None) +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, +) -> List[_TransformInfo]: + return _gen_transform_infos_non_cached(src_spec, dst_spec) + + def redistribute_local_tensor( local_tensor: torch.Tensor, current_spec: DTensorSpec, @@ -175,7 +181,13 @@ def redistribute_local_tensor( # which should be an empty tensor return local_tensor - transform_infos = _gen_transform_infos(current_spec, target_spec) + has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any( + isinstance(s, torch.SymInt) for s in target_spec.shape + ) + if has_symints: + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) for transform_info in transform_infos: i = transform_info.mesh_dim diff --git a/torch/distributed/_tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py similarity index 94% rename from torch/distributed/_tensor/_sharding_prop.py rename to torch/distributed/tensor/_sharding_prop.py index aefb4e41c94479..2b87d79a342b29 100644 --- a/torch/distributed/_tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import threading from functools import lru_cache from itertools import chain from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union @@ -6,7 +7,9 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode -from torch.distributed._tensor._op_schema import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( OpInfo, OpSchema, OpStrategy, @@ -17,13 +20,11 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor._utils import ( - compute_local_shape, +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, compute_local_stride, try_find_mesh_from_args, ) -from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.device_mesh import DeviceMesh aten = torch.ops.aten @@ -37,6 +38,17 @@ def _length(obj) -> int: return len(obj) +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + class ShardingPropagator: def __init__(self) -> None: self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} @@ -46,7 +58,9 @@ def __init__(self) -> None: ] = {} # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {} - self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign] + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) # op map to save indices of shape (and stride) args which may need to be modified in sharding prop self.op_to_shape_and_stride_idx: Dict[ OpOverload, Union[int, Tuple[int, int]] @@ -90,8 +104,7 @@ def register_op_strategy( if schema_info is not None: self.op_to_schema_info[op_overload] = schema_info - @lru_cache # noqa: B019 - def _propagate_tensor_meta( + def _propagate_tensor_meta_non_cached( self, op_schema: OpSchema ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: """ @@ -136,6 +149,12 @@ def _propagate_tensor_meta( # if fake is not a tensor or tuple of tensor, return as none return None + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + return self._propagate_tensor_meta_non_cached(op_schema) + def _wrap_output_spec_tensor_meta( self, op: OpOverload, @@ -183,7 +202,9 @@ def propagate(self, op_info: OpInfo) -> None: if op_info.schema.has_symints: output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: - output_sharding = self.propagate_op_sharding(op_info.schema) + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) op_info.output_sharding = output_sharding def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: @@ -195,7 +216,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if op_schema.op is aten._local_scalar_dense.default: return OutputSharding(None, op_schema) - out_tensor_meta = self._propagate_tensor_meta(op_schema) + out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): @@ -468,7 +489,7 @@ def _adjust_shape_and_stride_args( expected_input_schema = list(schema.args_schema) # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred - expected_input_schema[shape_idx] = compute_local_shape( + expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( out_tensor_meta.shape, mesh, spec.placements ) diff --git a/torch/distributed/_tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py similarity index 99% rename from torch/distributed/_tensor/_shards_wrapper.py rename to torch/distributed/tensor/_shards_wrapper.py index de396473b77c91..df8c7d09e38a4a 100644 --- a/torch/distributed/_tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -309,7 +309,7 @@ def __hash__(self): # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. # pyre-fixme[3]: Return type must be annotated. - def __repr__(self): + def __repr__(self) -> str: # type: ignore[override] return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" def __str__(self) -> str: diff --git a/torch/distributed/_tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py similarity index 99% rename from torch/distributed/_tensor/_tp_conv.py rename to torch/distributed/tensor/_tp_conv.py index cc6f1968e6ef99..ac11ef2162cbb7 100644 --- a/torch/distributed/_tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -5,7 +5,7 @@ import torch import torch.distributed as dist -import torch.distributed._tensor.api as dtensor +import torch.distributed.tensor._api as dtensor aten = torch.ops.aten diff --git a/torch/distributed/_tensor/_utils.py b/torch/distributed/tensor/_utils.py similarity index 90% rename from torch/distributed/_tensor/_utils.py rename to torch/distributed/tensor/_utils.py index b7c42f094be9e6..182d7f07552751 100644 --- a/torch/distributed/_tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,49 +1,17 @@ from typing import cast, List, Sequence, Tuple import torch -import torch.distributed._tensor.api as dtensor +import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType -from torch.distributed._tensor.placement_types import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( _StridedShard, - DTensorSpec, Partial, Placement, Replicate, Shard, ) -from torch.distributed.device_mesh import DeviceMesh - - -# TODO: audit existing code base to see if we can safely remove this API. -def compute_local_shape( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] -) -> Tuple[int, ...]: - """ - Compute the shape of a local shard of the given DTensor on its current - coordinate of the mesh. - """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: - # if rank not in the mesh, return empty shape - return (0,) - else: - local_shape = list(global_shape) # start with global shape - ndim = len(global_shape) - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if isinstance(placement, Shard): - shard_dim = placement.dim - assert ( - shard_dim < ndim - ), f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" - local_shard_size, _ = placement._local_shard_size_on_dim( - local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] - ) - assert isinstance(local_shard_size, int) - local_shape[shard_dim] = local_shard_size - - return tuple(local_shape) def compute_local_shape_and_global_offset( @@ -90,7 +58,7 @@ def compute_local_shape_and_global_offset( if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((), ()) + return ((0,), ()) else: local_shape = list(global_shape) global_offset = [0] * len(global_shape) diff --git a/torch/distributed/_tensor/debug/__init__.py b/torch/distributed/tensor/debug/__init__.py similarity index 58% rename from torch/distributed/_tensor/debug/__init__.py rename to torch/distributed/tensor/debug/__init__.py index 7ad08c376028dd..e5bf3b833fe478 100644 --- a/torch/distributed/_tensor/debug/__init__.py +++ b/torch/distributed/tensor/debug/__init__.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs -from torch.distributed._tensor.debug.comm_mode import CommDebugMode -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding +from torch.distributed.tensor.debug._comm_mode import CommDebugMode +from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding __all__ = ["CommDebugMode", "visualize_sharding"] @@ -12,8 +12,13 @@ def _get_sharding_prop_cache_info(): This would return a named tuple showing hits, misses, maxsize and cursize of the sharding propagator cache. """ - from torch.distributed._tensor.api import DTensor + from torch.distributed.tensor._api import DTensor return ( DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] ) + + +# Set namespace for exposed private names +CommDebugMode.__module__ = "torch.distributed.tensor.debug" +visualize_sharding.__module__ = "torch.distributed.tensor.debug" diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py similarity index 96% rename from torch/distributed/_tensor/debug/comm_mode.py rename to torch/distributed/tensor/debug/_comm_mode.py index f4632fdf691ed8..6a776c88cf0b82 100644 --- a/torch/distributed/_tensor/debug/comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -10,8 +10,8 @@ import torch.nn from torch._guards import detect_fake_mode from torch.autograd.graph import register_multi_grad_hook -from torch.distributed._tensor.api import DTensor from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor._api import DTensor from torch.nn.modules.module import ( register_module_forward_hook, register_module_forward_pre_hook, @@ -69,7 +69,7 @@ } -class CommModeModuleTracker(ModTracker): +class _CommModeModuleTracker(ModTracker): """ Inherits ModuleTracker and expands on its functionality to track the parameters and sharding information of a model at a module-level @@ -222,12 +222,11 @@ def print_sharding_info(self): class CommDebugMode(TorchDispatchMode): """ - ``CommDebugMode`` is a context manager that counts the number of + :class:`CommDebugMode` is a context manager that counts the number of functional collectives within its context. It does this using a ``TorchDispatchMode``. - NOTE: this mode only works for functional collective atm and the - distributed_c10d collectives are not supported yet. + .. note: Not all collectives are supported yet. Example usage @@ -237,7 +236,7 @@ class CommDebugMode(TorchDispatchMode): comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() - + print(comm_mode.get_comm_counts()) """ def __init__(self): @@ -250,7 +249,7 @@ def __init__(self): self.comm_registry.add(py_op) self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) - self.advanced_module_tracker = CommModeModuleTracker() + self.advanced_module_tracker = _CommModeModuleTracker() def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): """ @@ -266,7 +265,7 @@ def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): include_module_data, include_ops, include_trivial_ops, - ) = self.set_noise_parameters(noise_level) + ) = self._set_noise_parameters(noise_level) # recursively builds json data def add_json_information(json_dict, fqn): @@ -321,7 +320,9 @@ def add_json_information(json_dict, fqn): forward_operations, backward_operations, checkpointing_operations, - ) = self.get_operations_list(self.comm_module_operation_counts[fqn]) + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) # remove all operations who don't have DTensor inputs if not include_ops: @@ -414,7 +415,7 @@ def generate_comm_debug_tracing_table(self, noise_level=3): include_module_data, include_ops, include_trivial_ops, - ) = self.set_noise_parameters(noise_level) + ) = self._set_noise_parameters(noise_level) table = "" for fqn in self.advanced_module_tracker.module_helper_dict: @@ -469,7 +470,9 @@ def generate_comm_debug_tracing_table(self, noise_level=3): forward_operations, backward_operations, checkpointing_operations, - ) = self.get_operations_list(self.comm_module_operation_counts[fqn]) + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) def add_tracing_information(table, collectives_dict, operation_list): """ @@ -544,7 +547,7 @@ def add_operations( return table - def get_operations_list(self, module_operation_counts): + def _get_operations_list(self, module_operation_counts): forward_operations = [ op for op in module_operation_counts["operations_list"] if not op["is_bw"] ] @@ -572,12 +575,6 @@ def get_comm_counts(self) -> Dict[Any, int]: """ return self.comm_counts - def get_comm_module_counts(self) -> Dict[str, Dict[Any, int]]: - """ - Returns the communication counts at a module level as a dictionary. - """ - return self.comm_module_counts - def get_parameter_info(self) -> Dict[str, Dict[str, Any]]: return self.advanced_module_tracker.module_parameters_dict @@ -613,13 +610,7 @@ def log_comm_debug_tracing_table_to_file( with open(file_name, "w") as log_file: log_file.write(table) - def print_paramater_info(self): - self.advanced_module_tracker.print_paramater_info() - - def print_sharding_info(self): - self.advanced_module_tracker.print_sharding_info() - - def set_noise_parameters(self, noise_level): + def _set_noise_parameters(self, noise_level): """ sets variables controlling what information displays based on noise level """ diff --git a/torch/distributed/_tensor/debug/_op_coverage.py b/torch/distributed/tensor/debug/_op_coverage.py similarity index 98% rename from torch/distributed/_tensor/debug/_op_coverage.py rename to torch/distributed/tensor/debug/_op_coverage.py index 214c4f003ff2d6..258dc27a6c43db 100644 --- a/torch/distributed/_tensor/debug/_op_coverage.py +++ b/torch/distributed/tensor/debug/_op_coverage.py @@ -8,7 +8,7 @@ from functorch.compile import make_boxed_func from torch._functorch.compilers import aot_module from torch._inductor.decomposition import select_decomp_table -from torch.distributed._tensor import DTensor +from torch.distributed.tensor import DTensor inductor_decomps = select_decomp_table() diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/tensor/debug/_visualize_sharding.py similarity index 95% rename from torch/distributed/_tensor/debug/visualize_sharding.py rename to torch/distributed/tensor/debug/_visualize_sharding.py index c080fa0020e265..ade935d5133cf1 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/tensor/debug/_visualize_sharding.py @@ -4,8 +4,8 @@ import numpy as np from torch._prims_common import ShapeType -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.placement_types import Placement, Shard +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement, Shard __all__ = ["visualize_sharding"] @@ -135,10 +135,9 @@ def _compute_local_shape_and_global_offset( def visualize_sharding(dtensor, header=""): """ - Visualizes sharding in 1D-2D dtensors - Requires tabulate, install with `pip install tabulate` + Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D. - note: no sharding info will be printed for empty tensors + .. note:: This requires the ``tabulate`` package. No sharding info will be printed for empty tensors """ if dtensor.numel() == 0: # we do not print for empty dtensors return diff --git a/torch/distributed/_tensor/debug/comm_mode_broswer_visual.js b/torch/distributed/tensor/debug/comm_mode_broswer_visual.js similarity index 100% rename from torch/distributed/_tensor/debug/comm_mode_broswer_visual.js rename to torch/distributed/tensor/debug/comm_mode_broswer_visual.js diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/tensor/device_mesh.py similarity index 100% rename from torch/distributed/_tensor/device_mesh.py rename to torch/distributed/tensor/device_mesh.py diff --git a/torch/distributed/_tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py similarity index 99% rename from torch/distributed/_tensor/examples/comm_mode_features_example.py rename to torch/distributed/tensor/examples/comm_mode_features_example.py index b98a8a3962c4a2..98143973145330 100644 --- a/torch/distributed/_tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.debug import CommDebugMode +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -115,7 +115,7 @@ def example_MLP_distributed_sharding_display(self) -> None: output_tp = model(inp) output_tp.sum().backward() - comm_mode.print_sharding_info() + print(comm_mode.get_sharding_info()) def example_MLPStacked_distributed_sharding_display(self) -> None: """ @@ -152,7 +152,7 @@ def example_MLPStacked_distributed_sharding_display(self) -> None: output_tp = model(inp) output_tp.sum().backward() - comm_mode.print_sharding_info() + print(comm_mode.get_sharding_info()) def example_MLP_module_tracing(self) -> None: """ diff --git a/torch/distributed/_tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py similarity index 99% rename from torch/distributed/_tensor/examples/convnext_example.py rename to torch/distributed/tensor/examples/convnext_example.py index 14abad0033460f..57d7bca8cc08bc 100644 --- a/torch/distributed/_tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -12,7 +12,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._tensor import ( +from torch.distributed.tensor import ( DeviceMesh, distribute_module, distribute_tensor, diff --git a/torch/distributed/_tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py similarity index 98% rename from torch/distributed/_tensor/examples/torchrec_sharding_example.py rename to torch/distributed/tensor/examples/torchrec_sharding_example.py index 33f8c7017f5be4..9e6f4054e292b4 100644 --- a/torch/distributed/_tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -9,23 +9,23 @@ from typing import List, TYPE_CHECKING import torch -from torch.distributed._tensor import ( +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.tensor import ( DeviceMesh, DTensor, init_device_mesh, Replicate, Shard, ) -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding -from torch.distributed.checkpoint.metadata import ( - ChunkStorageMetadata, - TensorProperties, - TensorStorageMetadata, -) +from torch.distributed.tensor.debug import visualize_sharding if TYPE_CHECKING: - from torch.distributed._tensor.placement_types import Placement + from torch.distributed.tensor.placement_types import Placement def get_device_type(): diff --git a/torch/distributed/_tensor/examples/visualize_sharding_example.py b/torch/distributed/tensor/examples/visualize_sharding_example.py similarity index 92% rename from torch/distributed/_tensor/examples/visualize_sharding_example.py rename to torch/distributed/tensor/examples/visualize_sharding_example.py index 5494316428b594..71cad75ef95f81 100644 --- a/torch/distributed/_tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/tensor/examples/visualize_sharding_example.py @@ -6,8 +6,8 @@ import os import torch -from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard -from torch.distributed._tensor.debug.visualize_sharding import visualize_sharding +from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard +from torch.distributed.tensor.debug import visualize_sharding world_size = int(os.environ["WORLD_SIZE"]) diff --git a/torch/distributed/tensor/experimental/__init__.py b/torch/distributed/tensor/experimental/__init__.py new file mode 100644 index 00000000000000..5193034770af13 --- /dev/null +++ b/torch/distributed/tensor/experimental/__init__.py @@ -0,0 +1,32 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from contextlib import contextmanager + +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor.experimental._func_map import local_map +from torch.distributed.tensor.experimental._register_sharding import register_sharding + + +__all__ = ["implicit_replication", "local_map", "register_sharding"] + + +@contextmanager +def implicit_replication(): + """ + This context manager allows :class:`DTensor` to implicitly treat all non-DTensors (``torch.Tensor``) + in the program be replicate :class:`DTensor` s during the operator computation. + + .. warning:: This might possible lead to incorrect results if ``torch.Tensor`` s are not replicated + in practice, please use it at your discretion. + """ + try: + DTensor._op_dispatcher._allow_implicit_replication = True + yield + finally: + DTensor._op_dispatcher._allow_implicit_replication = False + + +# Set namespace for exposed private names +implicit_replication.__module__ = "torch.distributed.tensor.experimental" +local_map.__module__ = "torch.distributed.tensor.experimental" +register_sharding.__module__ = "torch.distributed.tensor.experimental" diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/tensor/experimental/_attention.py similarity index 65% rename from torch/distributed/_tensor/experimental/attention.py rename to torch/distributed/tensor/experimental/_attention.py index 7ed404872835aa..a00c92d9bba16c 100644 --- a/torch/distributed/_tensor/experimental/attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -5,6 +5,8 @@ import logging import types import weakref +from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from typing import ( Any, @@ -24,8 +26,8 @@ import torch.distributed._functional_collectives as ft_c import torch.nn.functional as F from torch import nn -from torch.distributed._tensor import distribute_module, DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard from torch.distributed.tensor.parallel.style import ParallelStyle @@ -34,10 +36,18 @@ aten = torch.ops.aten logger = logging.getLogger(__name__) -# Whether to upcast parameters and gradients to float32 to avoid accumulation -# errors. It is likely this is always True but we currently keep this variable -# for the experimental purpose. -_convert_to_f32 = True + + +@dataclass +class _ContextParallelOptions: + # Whether to upcast parameters and gradients to float32 to avoid accumulation + # errors. It is likely this is always True but we currently keep this variable + # for the experimental purpose. + convert_to_f32: bool = True + enable_load_balance = True + + +_cp_options = _ContextParallelOptions() class _CausalBehavior(Enum): @@ -60,7 +70,7 @@ def _is_causal_behavior( return _CausalBehavior.IS_CAUSAL source_rank = (rank - i) % world_size - if source_rank < rank: + if source_rank < rank or _cp_options.enable_load_balance: return _CausalBehavior.NOT_IS_CAUSAL else: return _CausalBehavior.SKIP @@ -76,31 +86,91 @@ def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: return tensor +def _partial_update( + original: torch.Tensor, + new: torch.Tensor, + dim: int, + n_chunks: int, + idx: int, + add: bool, +) -> torch.Tensor: + """ + This API partially update a chunk of ``original`` tensor. The ``original`` + tensor will be first chunked along ``dim`` dimension then the ``idx`` chunk + will be updated with ``new``. If ``add`` is True, the chunk will be added + with ``new``, otherwise the chunk with be replaced by ``add``. + + The result is a tensor that is the same size as ``original``. + """ + chunks = list(original.chunk(n_chunks, dim=dim)) + assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) + if add: + chunks[idx] += new + else: + chunks[idx] = new + return torch.cat(chunks, dim=dim) + + class _SDPAMerger: """A class to help to merge the local SDPA result.""" - def __init__(self, convert_to_f32: bool): + def __init__(self, convert_to_f32: bool, seq_dim: int): + self._seq_dim = seq_dim self._out: Optional[torch.Tensor] = None self._lse: Optional[torch.Tensor] = None self._convert_to_f32 = convert_to_f32 self._out_dtype = torch.float32 self._lse_dtype = torch.float32 - def _merge_one(self, block_out: torch.Tensor, block_lse: torch.Tensor) -> None: + def _merge_one( + self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool + ) -> None: block_lse = block_lse.unsqueeze(dim=-1) if self._lse is None: self._lse = block_lse self._out = block_out else: + ROUND_ROBIN_CYCLE = 2 + assert self._lse is not None + assert self._out is not None + lse = ( + self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._lse + ) + out = ( + self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._out + ) + # The algorithm from # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 # gives a relatively stable result. - self._out = self._out - F.sigmoid(block_lse - self._lse) * ( - self._out - block_out - ) - self._lse = self._lse - F.logsigmoid(self._lse - block_lse) + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + if partial: + self._lse = _partial_update( + self._lse, + lse, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + self._out = _partial_update( + self._out, + out, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + else: + self._lse = lse + self._out = out - def step(self, out: torch.Tensor, lse: torch.Tensor) -> None: + def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: self._out_dtype = out.dtype self._lse_dtype = lse.dtype @@ -108,7 +178,7 @@ def step(self, out: torch.Tensor, lse: torch.Tensor) -> None: out = out.to(torch.float32) lse = lse.to(torch.float32) - self._merge_one(out, lse) + self._merge_one(out, lse, partial) def results(self) -> Tuple[torch.Tensor, torch.Tensor]: assert self._out is not None @@ -131,8 +201,10 @@ def _scaled_dot_product_ring_flash_attention( if return_debug_mask: raise NotImplementedError("return_debug_mask is not supported yet") + seq_dim = 2 return _templated_ring_attention( mesh, + seq_dim, aten._scaled_dot_product_flash_attention, query=query, key=key, @@ -160,8 +232,10 @@ def _scaled_dot_product_ring_efficient_attention( if not compute_log_sumexp: raise NotImplementedError("compute_log_sumexp must be set") + seq_dim = 2 return _templated_ring_attention( mesh, + seq_dim, aten._scaled_dot_product_efficient_attention, query=query, key=key, @@ -188,6 +262,7 @@ def __call__( def _ring_rotate( block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool ) -> torch.Tensor: + block = block.contiguous() size = dist.get_world_size(pg) dsts = ( list(range(1, size)) + [0] @@ -199,6 +274,7 @@ def _ring_rotate( def _templated_ring_attention( mesh: DeviceMesh, + seq_dim: int, op: _AttentionOp, query: torch.Tensor, key: torch.Tensor, @@ -209,6 +285,61 @@ def _templated_ring_attention( """ This is a generalized ring attention implementation that can support multiple attention ops. + Note [Context parallelism load balance algorithm for causal masking] + ===================== + This explanation uses an example to illustrate the CP algorithm with causal + masking. + + Consider a scenario where the sequence length of q, k, and v is 4 (e.g., + q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss + only q and k, as v follows the same pattern as k. + + The diagram below represents a complete QK^T operation without parallelism. + The `****` entries indicate that the result is not required due to causal + masking (e.g., q0k1 is marked as `****`). + + +----+------------------------+ + | | k0 k1 k2 k3 | + +----+------------------------+ + | q0 | q0k0, ****, ****, **** | + | q1 | q1k0, q1k1, ****, **** | + | q2 | q2k0, q2k1, q2k2, **** | + | q3 | q3k0, q3k1, q3k2, q3k3 | + +----+------------------------+ + + ### No Load Balance: + + In this scenario, each rank owns a local chunk of q, k, and v, with each chunk + containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), + while rank1 manages (q2, q3) and (k2, k3). + + First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. + Causal masking is enabled as some results are not required (e.g., q0k1). + + Second Iteration: Local queries remain the same, but local kv pairs are exchanged. + Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs + no computation, while rank1 computes locally without causal masking since all results + (q2k0, q2k1, q3k0, q3k1) are needed. + + ### Round-robin Load Balance: + + In this setup, each rank owns two local chunks of q, k, and v, with each chunk + containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) + and (k1, k2). Although the local chunks are not consecutive, they are concatenated to + enable SDPA to be performed in a single call for each step. Consequently, the chunk() + function may be required to prepare the correct q, k, and v configurations. + + First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the + no-load-balance case. This iteration corresponds to the `if` of the + (`if, `elif`, `else`) in the implemementation. + + Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and + (k0, k3). For rank0, no computation is needed for q0. However, computations for + q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the + `else` of the (`if`, `elif`, `else`) in the implemementation. + For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This + corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. + Parameters ---------- op: @@ -246,20 +377,21 @@ def _templated_ring_attention( key = key.contiguous() value = value.contiguous() - sdpa_merger = _SDPAMerger(_convert_to_f32) + sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) rest: List[Any] out: torch.Tensor logsumexp: torch.Tensor for i in range(size): - # overlap communication with compute if next_kv is not None: + # Wait for the kv from the (cp_rank - 1) rank. next_kv = _maybe_wait(next_kv) key = next_kv[: key.numel()].reshape(key.shape) value = next_kv[key.numel() :].reshape(value.shape) if i < (size - 1): + # Send the k, v to the next rank next_kv = torch.cat([key.flatten(), value.flatten()]) next_kv = _ring_rotate(next_kv, pg, send_to_next=True) @@ -267,16 +399,44 @@ def _templated_ring_attention( rank=rank, world_size=size, i=i, is_causal=is_causal ) - if is_causal_behavior != _CausalBehavior.SKIP: - out, logsumexp, *rest = op( + # For a detailed understanding of the load balancing algorithm, see + # Note [Context parallelism load balance algorithm for causal masking] + if is_causal_behavior == _CausalBehavior.SKIP: + # If i > rank and load balancing is not turned on. + continue + + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # When local balance is enabled, we still need to do SDPA with + # the both local chunks of q, k, v for the first iteration. + q, k, v, partial = (query, key, value, False) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA, with only the first local chunk of the k, v. + # Note that q, k, v, each contains two local chunks. + ROUND_ROBIN_CYCLE = 2 + q, k, v, partial = ( query, - key, - value, - is_causal=is_causal_behavior.value, - **kwargs, + key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + False, ) - - sdpa_merger.step(out, logsumexp) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q, and update + # only the the second part of logsumexp. So partial is True. + # Note that q, k, v, each contains two chunks. + q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True + + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + out, logsumexp, *rest = op( + q, + k, + v, + is_causal=is_causal_behavior.value, + **kwargs, + ) + sdpa_merger.step(out, logsumexp, partial) return *sdpa_merger.results(), *rest @@ -358,6 +518,7 @@ def _sdpa_backward_handler( def _templated_ring_attention_backward( mesh: DeviceMesh, + seq_dim: int, op: _AttentionOp, grad_out: torch.Tensor, grad_out_name: str, @@ -369,6 +530,7 @@ def _templated_ring_attention_backward( is_causal: bool, **kwargs: Any, ) -> Tuple[torch.Tensor, ...]: + """This API implements the backward of the ring attention.""" pg = mesh.get_group() assert isinstance(pg, dist.ProcessGroup), "must be single dimension" rank = dist.get_rank(pg) @@ -378,7 +540,7 @@ def _templated_ring_attention_backward( rest: List[Any] grad_query_, grad_key_, grad_value_ = None, None, None - accum_dtype = torch.float32 if _convert_to_f32 else query.dtype + accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype grad_query = torch.zeros_like(query, dtype=accum_dtype) grad_key = torch.zeros_like(key, dtype=accum_dtype) grad_value = torch.zeros_like(value, dtype=accum_dtype) @@ -387,6 +549,7 @@ def _templated_ring_attention_backward( value = value.contiguous() for i in range(size): if next_kv is not None: + # Wait for the kv from the (cp_rank - 1) rank. buffer = _maybe_wait(next_kv) pointer = 0 key = buffer[pointer : pointer + key.numel()].reshape(key.shape) @@ -395,6 +558,7 @@ def _templated_ring_attention_backward( pointer += value.numel() if i != size - 1: + # Send the kv to the next rank. next_kv = torch.cat([key.flatten(), value.flatten()]) next_kv = _ring_rotate(next_kv, pg, send_to_next=True) @@ -403,13 +567,45 @@ def _templated_ring_attention_backward( ) if is_causal_behavior != _CausalBehavior.SKIP: - kwargs[grad_out_name] = grad_out + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # We need to do SDPA with the full local q, k, v. + q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA with only the first half of the k, v. + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query, + key.chunk(2, dim=seq_dim)[0], + value.chunk(2, dim=seq_dim)[0], + out, + grad_out, + logsumexp, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query.chunk(2, dim=seq_dim)[1], + key, + value, + out.chunk(2, dim=seq_dim)[1], + grad_out.chunk(2, dim=seq_dim)[1], + # Need to make logsumexp contiguous, otherwise there will + # be numerical error. + logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), + ) + + kwargs[grad_out_name] = dout + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. grad_query_, grad_key_, grad_value_, *rest = op( - query=query, - key=key, - value=value, - out=out, - logsumexp=logsumexp, + query=q, + key=k, + value=v, + out=out_, + logsumexp=lse, is_causal=is_causal_behavior.value, **kwargs, ) @@ -418,10 +614,14 @@ def _templated_ring_attention_backward( grad_key_ = torch.zeros_like(key, dtype=accum_dtype) grad_value_ = torch.zeros_like(value, dtype=accum_dtype) - # Get the grad key and grad value for the i round. - if i > 0: + ROUND_ROBIN_CYCLE = 2 + if i == 0: + grad_key += grad_key_ + grad_value += grad_value_ + else: pointer = 0 assert next_grad_kv is not None + # Wait for the kv gradient from (cp_rank - 1) rank. next_grad_kv = _maybe_wait(next_grad_kv) grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( grad_key.shape @@ -431,13 +631,42 @@ def _templated_ring_attention_backward( grad_value.shape ) - grad_key += grad_key_ - grad_value += grad_value_ + if i <= rank and _cp_options.enable_load_balance: + grad_key = _partial_update( + grad_key, + grad_key_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + grad_value = _partial_update( + grad_value, + grad_value_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + else: + grad_key += grad_key_ + grad_value += grad_value_ - # Send the key, value, grad key, and grad value to the next rank. next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) + # Send the grad key, and grad value to the next rank. next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True) - grad_query += grad_query_ + + if i <= rank or not _cp_options.enable_load_balance: + grad_query += grad_query_ + else: + grad_query = _partial_update( + grad_query, + grad_query_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=True, + ) assert next_grad_kv is not None assert grad_key_ is not None @@ -473,8 +702,10 @@ def _scaled_dot_product_ring_flash_attention_backward( *, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, ...]: + seq_dim = 2 return _templated_ring_attention_backward( mesh, + seq_dim, aten._scaled_dot_product_flash_attention_backward.default, grad_out=grad_out, grad_out_name="grad_out", @@ -512,8 +743,10 @@ def _scaled_dot_product_ring_efficient_attention_backward( *, scale: Optional[float] = None, ) -> Tuple[torch.Tensor, ...]: + seq_dim = 2 return _templated_ring_attention_backward( mesh, + seq_dim, aten._scaled_dot_product_efficient_attention_backward.default, grad_out=grad_out, grad_out_name="grad_out_", @@ -779,10 +1012,94 @@ def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any: _restore_function(F.scaled_dot_product_attention, F) -def _get_sequence_shard( - buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int -) -> torch.Tensor: - return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] +class _LoadBalancer(ABC): + @classmethod + @abstractmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + ... + + @classmethod + @abstractmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + ... + + +class _SequentialSharder(_LoadBalancer): + """ + This load balancer chunks the buffer into cp_world_size and rank0 gets + 0th shard, rank1 gets 1st shard, ... + So this doesn't have any load balancing effect when using the causal masking. + """ + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert buffer.size()[seq_dim] % mesh.size() == 0 + return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + buffer = buffer.contiguous() + all_buffers = [torch.empty_like(buffer) for _ in range(mesh.size())] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + return torch.cat(all_buffers, dim=seq_dim) + + +class _RoundRobinLoadBalancer(_LoadBalancer): + """ + This load balancer chunk the buffer into cp_world_size * ROUND_ROBIN_CYCLE + shards, and uses a round robin approach to achieve load balancing. + Since ROUND_ROBIN_CYCLE being 2 will achieve perfect load balancing for + causal masking, we assume ROUND_ROBIN_CYCLE is always 2 to simplify the + implementation. + """ + + ROUND_ROBIN_CYCLE = 2 + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert ( + cls.ROUND_ROBIN_CYCLE == 2 + ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 + chunks = buffer.chunk(cp_world_size * 2, dim=seq_dim) + return torch.cat( + (chunks[cp_rank], chunks[cp_world_size * 2 - cp_rank - 1]), + dim=seq_dim, + ) + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert ( + cls.ROUND_ROBIN_CYCLE == 2 + ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + buffer = buffer.contiguous() + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + + all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + sliced_buffers = [sb for b in all_buffers for sb in b.chunk(2, dim=seq_dim)] + ordered_buffers = list(sliced_buffers) + for i, b in enumerate(sliced_buffers): + if i % 2 == 0: + ordered_buffers[i // 2] = b + else: + ordered_buffers[cp_world_size * 2 - (i // 2) - 1] = b + return torch.cat(ordered_buffers, dim=seq_dim) def _context_parallel_buffers( @@ -792,8 +1109,13 @@ def _context_parallel_buffers( ) -> List[torch.Tensor]: """Shard the buffers along the sequence dimensions according to CP rules.""" new_buffers = [] + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) for buffer, seq_dim in zip(buffers, buffer_seq_dims): - new_buffers.append(_get_sequence_shard(buffer, mesh, seq_dim)) + new_buffers.append(sharder.shard(buffer, mesh, seq_dim)) return new_buffers @@ -851,7 +1173,6 @@ def context_parallel( raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] - chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims) for buffer, chunk in zip(buffers, chunks): chunk = chunk.clone() @@ -865,3 +1186,20 @@ def context_parallel( if original_buffer is not None: buffer.resize_(original_buffer.shape) buffer.copy_(original_buffer) + + +@torch.no_grad() +def context_parallel_unshard( + mesh: DeviceMesh, + buffers: List[torch.Tensor], + seq_dims: List[int], +) -> List[torch.Tensor]: + """ + Unshard the tensors (e.g., output) that are sharded due to context parallelism. + """ + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)] diff --git a/torch/distributed/_tensor/experimental/func_map.py b/torch/distributed/tensor/experimental/_func_map.py similarity index 77% rename from torch/distributed/_tensor/experimental/func_map.py rename to torch/distributed/tensor/experimental/_func_map.py index 667fbb484fb799..23b69373e74e57 100644 --- a/torch/distributed/_tensor/experimental/func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -4,8 +4,8 @@ import torch from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor.placement_types import Placement try: @@ -16,7 +16,6 @@ __all__ = ["local_map"] - PlacementType = Optional[Sequence[Placement]] InputPlacements = Optional[Tuple[PlacementType, ...]] OutputPlacements = Union[PlacementType, Tuple[PlacementType, ...]] @@ -31,32 +30,34 @@ def local_map( redistribute_inputs: bool = False, ): """ - ``local_map`` is an experimental API that allows users to apply on :class:`DTensors` - a function that is written to be applied on :class:`~torch.Tensors`. + :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s + to a function that is written to be applied on ``torch.Tensor`` s. It is done by extracting + the local components of :class:`DTensor`, call the function, and wrap the outputs to + :class:`DTensor` according to the ``out_placements``. Args: func (Callable): the function to be applied on each local shard of - :class:`DTensor`s. + :class:`DTensor` s. out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): - the desired placements of the :class:`DTensor`s in ``func``'s flattened output. + the desired placements of the :class:`DTensor` s in ``func``'s flattened output. If the flattened ``output`` is a single value, the ``out_placements`` should be of type `PlacementType`. Otherwise if the flattened ``output`` has multiple values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1 mapping to the flattened ``output``. Besides, for :class:`Tensor` output, we use `PlacementType` as its - placements (a `Tuple[Placement]` value). For non-:class:`Tensor` output, - the `PlacementType` should be `None`. + placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType` + should be `None`. Note that the only exception is when no :class:`DTensor` argument is passed in. In this case, even if `out_placements` is not `None`, the result function - should ignore the desired placements because the application is not on - :class:`DTensors`. + should ignore the desired placements because the function is not running with + :class:`DTensor` s. in_placements (Tuple[`PlacementType`, ...], optional): - the required placements of the :class:`DTensor`s in ``func``'s flattened input. - If ``in_placements`` is specified, ``local_map`` would examine whether the + the required placements of the :class:`DTensor` s in the flattened inputs of ``func``. + If ``in_placements`` is specified, :meth:`local_map` would examine whether the placements of each :class:`DTensor` argument is the same as the required placements or not. If the placements are not the same and ``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if - ``redistribute_inputs`` is `True`, the argument will be first redistributed to + ``redistribute_inputs`` is ``True``, the argument will be first redistributed to the required sharding placements before passing its local tensor to ``func``. The only exception is when required placements are not ``None`` and the argument is a :class:`torch.Tensor`. In this case, the placements examination @@ -64,12 +65,12 @@ def local_map( If ``in_placements`` is ``None``, no placements examination will be performed. Default: None device_mesh (:class:`DeviceMesh`, optional): - the device mesh that all the :class:`DTensor`s are placed on. If not - specified, this will be inferred from the input :class:`DTensor`s' device - mesh. `local_map` requires every :class:`DTensor`s to be placed on the same + the device mesh that all the :class:`DTensor` s are placed on. If not + specified, this will be inferred from the input :class:`DTensor` s' device + mesh. `local_map` requires every :class:`DTensor` s to be placed on the same device mesh. Default: None. redistribute_inputs (bool, optional): - the bool value indicating whether to reshard the input :class:`DTensor`s when + the bool value indicating whether to reshard the input :class:`DTensor` s when their placements are different from the required input placements. If this value is ``False`` and some :class:`DTensor` input has a different placement, an exception will be raised. Default: False. @@ -79,16 +80,16 @@ def local_map( and returns a :class:`DTensor` constructed from the return value of ``func``. Raises: - AssertionError: If the input :class:`DTensor`s are not placed on the same device - mesh, or if they are placed on a different device mesh than the ``device_mesh`` - argument passed in. + AssertionError: If the input :class:`DTensor` is not placed on the same device + mesh, or if they are placed on a different device mesh than the ``device_mesh`` + argument passed in. - AssertionError: For any non-:class:`DTensor` output, we require its corresponding - output placement in ``out_placements`` be None. An AssertionError will be raised - if this is not the case. + AssertionError: For any non-DTensor output, we require its corresponding + output placement in ``out_placements`` be None. An AssertionError will be raised + if this is not the case. ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs - a redistribution according to ``in_placements``. + a redistribution according to ``in_placements``. Example: >>> # xdoctest: +SKIP("distributed") @@ -115,7 +116,7 @@ def local_map( >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors - NOTE: This API is currently experimental and subject to change + .. note:: This API is currently experimental and subject to change """ def wrapped(*args, **kwargs): @@ -194,13 +195,17 @@ def wrapped(*args, **kwargs): flat_out, out_spec = pytree.tree_flatten(out) flat_dist_out = [] - for idx, out in enumerate(flat_out): - spec = ( - out_placements[idx] - if isinstance(out_placements, tuple) - else out_placements - ) - + out_placements_tuple = ( + out_placements + if isinstance(out_placements, tuple) + else (out_placements,) + ) + assert len(flat_out) == len(out_placements_tuple), ( + "local_map requires one PlacementType be provided for each output value," + f" received {len(out_placements_tuple)} out_placements but" + f" {len(flat_out)} is expected!" + ) + for out, spec in zip(flat_out, out_placements_tuple): if isinstance(out, torch.Tensor): assert not isinstance( out, DTensor diff --git a/torch/distributed/_tensor/experimental/register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py similarity index 93% rename from torch/distributed/_tensor/experimental/register_sharding.py rename to torch/distributed/tensor/experimental/_register_sharding.py index a6a70c162e97fc..c526e2e0a44090 100644 --- a/torch/distributed/_tensor/experimental/register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -5,8 +5,8 @@ import torch from torch._ops import OpOverload -from torch.distributed._tensor import DeviceMesh, DTensor -from torch.distributed._tensor._op_schema import ( +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._op_schema import ( _is_inplace_op, OpSchema, OpStrategy, @@ -15,7 +15,7 @@ StrategyType, TupleStrategy, ) -from torch.distributed._tensor.ops.utils import expand_to_full_mesh_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy __all__ = ["register_sharding"] @@ -23,8 +23,8 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]): """ - ``register_sharding`` is an experimental API that allows users to register sharding - strategies for an operator when the tensor inputs and outputs are :class:`DTensor`s. + :meth:`register_sharding` is an experimental API that allows users to register sharding + strategies for an operator when the tensor inputs and outputs are DTensor. It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``, e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2) when users would like to overwrite default sharding strategies of existing operators. @@ -62,6 +62,8 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]): >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings + + .. note:: This API is currently experimental and subject to change """ def custom_strategy( diff --git a/torch/distributed/_tensor/experimental/tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py similarity index 98% rename from torch/distributed/_tensor/experimental/tp_transform.py rename to torch/distributed/tensor/experimental/_tp_transform.py index 114192c9880494..81b47c9b5a7d3b 100644 --- a/torch/distributed/_tensor/experimental/tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -5,22 +5,17 @@ import torch from torch._subclasses.fake_tensor import FakeTensor -from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor -from torch.distributed._tensor._op_schema import ( - DTensorSpec, +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( OpSchema, OutputSharding, OutputSpecType, PlacementStrategy, ) -from torch.distributed._tensor._redistribute import redistribute_local_tensor -from torch.distributed._tensor.placement_types import ( - Placement, - Replicate, - Shard, - TensorMeta, -) +from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard from torch.export import ExportedProgram from torch.export.exported_program import ExportGraphSignature from torch.fx import GraphModule diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index 2e1ebfd53ab7b3..0d1097f4328c3e 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -3,8 +3,8 @@ import torch from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor import DTensor -from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec @no_type_check diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py index 0eace5f6d78061..f50b5dd64768d0 100644 --- a/torch/distributed/tensor/parallel/_utils.py +++ b/torch/distributed/tensor/parallel/_utils.py @@ -2,9 +2,9 @@ import warnings from typing import Tuple, Union -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement try: diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index e0fc4d2ef2b725..db4db018cce2d3 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -3,10 +3,10 @@ from typing import Dict, Union import torch -import torch.distributed._tensor.random as random +import torch.distributed.tensor._random as random import torch.nn as nn -from torch.distributed._tensor import DeviceMesh -from torch.distributed._tensor.random import ( +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor._random import ( is_rng_supported_mesh, TensorParallelRNGTracker, ) diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 8c4a7a655a8ffd..d2faa9ed32dc5d 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -14,12 +14,12 @@ ) from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.device_mesh import _mesh_resources from torch.distributed.fsdp._common_utils import _set_fsdp_flattened from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard from torch.distributed.tensor.parallel._data_parallel_utils import ( _flatten_tensor, _unflatten_tensor, diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 4e7af55d32c356..ab113246b7071c 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple import torch -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard __all__ = [ @@ -39,6 +39,9 @@ def input_reshard( Return: A :class:`nn.Module` object registered with TP input resharding. """ + if input_reshard_dim is None: + return module + cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> None: @@ -56,8 +59,6 @@ def input_reshard_backward_hook( nonlocal cx cx.__exit__() # type: ignore[name-defined, union-attr] - if input_reshard_dim is None: - return module module.register_forward_pre_hook(input_reshard_forward_pre_hook) module.register_forward_hook(input_reshard_backward_hook) return module diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 79f6f08e595393..99f1e3ad6ef9ad 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -8,15 +8,16 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch import Tensor -from torch.distributed._tensor import DTensor, Replicate, Shard -from torch.distributed._tensor.ops._embedding_ops import _MaskPartial -from torch.distributed._tensor.ops._math_ops import ( +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops._math_ops import ( _skip_dim, Reduction, replicate_reduction_dims, ) -from torch.distributed._tensor.placement_types import DTensorSpec, Placement, TensorMeta -from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import Placement aten = torch.ops.aten diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 46b19a918296c4..d1e8dd7e236cd1 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from torch.distributed._tensor import ( +from torch.distributed.tensor import ( DeviceMesh, distribute_module, distribute_tensor, @@ -14,7 +14,7 @@ Replicate, Shard, ) -from torch.distributed._tensor.placement_types import Placement +from torch.distributed.tensor.placement_types import Placement __all__ = [ diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py new file mode 100644 index 00000000000000..0d9834ab8b81cd --- /dev/null +++ b/torch/distributed/tensor/placement_types.py @@ -0,0 +1,652 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass +from typing import cast, List, Optional, Tuple + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import ( + fill_empty_tensor_to_shards, + mesh_broadcast, + mesh_scatter, + pad_tensor, + shard_dim_alltoall, + unpad_tensor, +) + + +__all__ = ["Placement", "Shard", "Replicate", "Partial"] + + +class Placement: + """ + The base class for the Placement type, where it describes how a DTensor is placed onto the + ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout. + It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``, + and ``Partial``. + + This class is not meant to be used directly, mainly served as a typing stub. + """ + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + is_shard_instance = isinstance(self, Shard) + if dim is not None and is_shard_instance: + return cast(Shard, self).dim == dim + else: + return is_shard_instance + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, Partial) + + +@dataclass(frozen=True) +class Shard(Placement): + """ + The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension + ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension only holds a shard/piece of the global Tensor. The + ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the + last few shards on the DeviceMesh dimension might be empty when the tensor dimension + is not evenly divisble on the DeviceMesh dimension. The ``Shard`` placement can be + used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) + + Args: + dim (int): The tensor dimension that describes the DTensor is sharded over its + corresponding DeviceMesh dimension. + + .. warning:: sharding on a tensor dimension where the tensor dimension size is not + evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + """ + + dim: int + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + num_empty_tensors = num_chunks - len(tensor_list) + + # if no need to have padding or tensor dim size is evenly sharded already + # we can return early. + if not with_padding or tensor.size(self.dim) % num_chunks == 0: + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + return ( + fill_empty_tensor_to_shards(tensor_list, self.dim, num_empty_tensors), + [], + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk for ``self.dim`` + chunk_sizes = [ + tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 + for idx in range(num_chunks) + ] + # Compute pad size on each chunk + pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] + + # Reuse tensor to fill empty chunk with empty tensor + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_empty_tensors + ) + shard_list = [] + for shard, pad_size in zip(tensor_list, pad_sizes): + # Fill the empty tensor with zeroes with padding. + if with_padding and pad_size > 0: + shard = pad_tensor(shard, self.dim, pad_size) + shard = shard.contiguous() if contiguous else shard + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + def _local_shard_size_on_dim( + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + # Compute the chunk size inline with ``torch.chunk`` + if size_on_dim % num_chunks == 0: + full_chunk_size = size_on_dim // num_chunks + return full_chunk_size, full_chunk_size * rank if return_offset else -1 + + # uneven sharding case + full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if size_on_dim < shard_starting_idx: + return 0, size_on_dim if return_offset else -1 + else: + local_shard_size = ( + min(size_on_dim, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx if return_offset else -1 + + def _shard_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + shard and scatter a tensor on a mesh dimension (use coordinate + 0 on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + output = torch.empty_like(scatter_list[mesh_dim_local_rank]) + mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim) + + # Only unpad if the local_tensor was padded on the dimension. + if pad_sizes and pad_sizes[mesh_dim_local_rank] > 0: + output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) + return output + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: str, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return tensor + + is_padded = tensor.size(self.dim) % num_chunks != 0 + if is_padded: + scattered_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + tensor = torch.cat(scattered_list, dim=self.dim) + elif not tensor.is_contiguous(): + tensor = tensor.contiguous() + + output = funcol.reduce_scatter_tensor( + tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) + ) + + if is_padded: + output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] + return output + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + # check if it's uneven, so we need to pad input tensor before all_gather + local_shape = list(local_tensor.size()) + + logical_dim_size = current_logical_shape[self.dim] + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + pad_size = full_chunk_size - local_shape[self.dim] + local_tensor = pad_tensor(local_tensor, self.dim, pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if is_padded: + unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] + result = unpad_tensor(result, self.dim, unpad_size) + return result + + def _replicate_to_shard( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_index: int, + ) -> torch.Tensor: + """ + transform from replicated tensor to a sharded tensor on + the current rank, which would perform a local chunk + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + shards, _ = self._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + return shards[shard_index].clone() + + def _to_new_shard_dim( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + new_shard_dim: int, + ) -> torch.Tensor: + """ + transform from existing sharded tensor to a new sharded tensor on + that shard on a new dimension, which performs an alltoall + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return local_tensor + + num_chunks = mesh.size(mesh_dim=mesh_dim) + + old_dim_logical_size = current_logical_shape[self.dim] + new_dim_logical_size = current_logical_shape[new_shard_dim] + old_dim_padding = old_dim_logical_size % num_chunks != 0 + new_dim_padding = new_dim_logical_size % num_chunks != 0 + if old_dim_padding: + old_dim_full_chunk_size = ( + old_dim_logical_size + num_chunks - 1 + ) // num_chunks + old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) + if new_dim_padding: + new_dim_full_chunk_size = ( + new_dim_logical_size + num_chunks - 1 + ) // num_chunks + new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( + new_shard_dim + ) + local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + new_tensor = shard_dim_alltoall( + local_tensor, self.dim, new_shard_dim, mesh, mesh_dim + ) + + if old_dim_padding: + old_dim_unpad_size = ( + old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] + ) + new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] + + if new_dim_padding: + local_shard_size_on_new_dim = self._local_shard_size_on_dim( + new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] + )[0] + new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] + new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] + + return new_tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __hash__(self) -> int: + return hash(self.dim) + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +# kw_only is only available in python >= 3.10 +kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} + + +@dataclass(frozen=True, **kw_only_dataclass) +class _StridedShard(Shard): + """ + _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor + is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. + We call this right-to-left sharding which is the opposite of the default + left-to-right sharding. See the example below: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [Shard(0), Shard(0)] + + The default sharding behavior shards the tensor on "dp" mesh dimension first then + "tp" dimension. The sharding result will be: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 1 (row 2-3) + 2 | (1, 0) | 2 (row 4-5) + 3 | (1, 1) | 3 (row 6-7) + + While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on + "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the + result: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The consequence is, any attempt to redistribute this DTensor to a full replica will + produce a wrong result because the shard-to-replicate redistribution always happens + right-to-left, regardless it's left-to-right sharding or right-to-left. To address + this, we use _StridedShard placement to make this right-to-left sharding compatible + with our left-to-right convention on both tensor distribution and redistribution. + + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + And a left-to-right processing of `placements` will produce the same result, which is + different from using the `Shard` placement: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on "dp" dim is 2. + + TODO: strided sharding needs to work fine with uneven sharding. Now it forbids + resharding if the tensor is unevenly sharded. + TODO: we should remove _StridedShard placement once we can unify it with Shard + """ + + split_factor: int + + def __eq__(self, other: object) -> bool: + if isinstance(other, _StridedShard): + return self.dim == other.dim and self.split_factor == other.split_factor + elif isinstance(other, Shard): + # TODO: this is to avoid extra all-gather in dtensor op dispatch + # note that sharding prop would not produce _StridedShard and an + # placement inequality would introduce an all-gather for resharding + return self.dim == other.dim + return False + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + TODO: currently _StridedShard does not support padding + """ + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + + total_split = num_chunks * self.split_factor + assert tensor.size(self.dim) % total_split == 0, ( + "_StridedShard currently only allows even sharding but got tensor size" + f" {tensor.size(self.dim)} on dim {self.dim} and total split" + f" {total_split}={num_chunks} * {self.split_factor}" + ) + + group_size = self.split_factor + total_split_tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) + tensor_list = [ + torch.cat( + [ + total_split_tensor_list[i + j * num_chunks] # stride is num_chunks + for j in range(group_size) + ], + dim=self.dim, + ) + for i in range(num_chunks) + ] + + if contiguous: + tensor_list = [t.contiguous() for t in tensor_list] + + return tensor_list, [] + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: List[int], + ) -> torch.Tensor: + """ + Note: currently _StridedShard does not support padding + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + total_split = num_chunks * self.split_factor + # NOTE: we require Strided Sharding to be even for now + assert current_logical_shape[self.dim] % total_split == 0, ( + "_StridedShard requires even sharding but got tensor size " + f"{current_logical_shape[self.dim]} on dim {self.dim} and " + f"total split {total_split}=num_chunks {num_chunks} " + f"* split_factor {self.split_factor}" + ) + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(result, funcol.AsyncCollectiveTensor): + result = result.wait() + + tensor_shard_list = torch.chunk(result, total_split, dim=self.dim) + # rearrange the order + new_tensor_shard_list = [] + for idx in range(len(tensor_shard_list)): + # the shard split of index `idx` is assigned a new index within + # _StridedShard._split_tensor: + # the original tensor was split into `total_split` chunks, + # all chunks with the same `idx % num_chunks` are merged into one + # new shard and placed on mesh's local rank `idx % num_chunks` + idx_after_split = idx % num_chunks * self.split_factor + idx // num_chunks + new_tensor_shard_list.append(tensor_shard_list[idx_after_split]) + + return torch.cat(new_tensor_shard_list, dim=self.dim).contiguous() + + +@dataclass(frozen=True) +class Replicate(Placement): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __eq__(self, other: object) -> bool: + return isinstance(other, Replicate) + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + def _replicate_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim) + return tensor + + +@dataclass(frozen=True) +class Partial(Placement): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. + """ + + reduce_op: str = "sum" + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value + # - i.e. _partition_value on a sum reduce op is just a divison operation + # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation + # TODO: if the reduce_op is min/max, etc. the _partition_value should be a + # different operation + assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + num_chunks = mesh.size(mesh_dim=mesh_dim) + return tensor / num_chunks + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Partial): + return False + return self.reduce_op == other.reduce_op + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return "P" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index a4502a4b52e48e..faacd059f899e1 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -232,10 +232,10 @@ def apply(x): return fn(x) elif hasattr(x, "__dataclass_fields__"): dc = dataclasses.replace(x) - for f in dataclasses.fields(dc): - name = f.name - setattr(dc, name, apply(getattr(dc, name))) - return dc + changes = { + f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) + } + return dataclasses.replace(dc, **changes) elif isinstance(x, OrderedDict): od = x.__class__() for key, value in x.items(): diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 3c510bd32abc62..e6f730b123afd1 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -204,7 +204,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, fn): + def __call__(self, fn): # type: ignore[override] """ Support for syntax to customize static attributes:: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index bd8fa87f2619a4..a4d403383d9cb0 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -177,7 +177,7 @@ def sample(self, sample_shape=torch.Size()): self._loc, self._concentration, self._proposal_r, x ).to(self.loc.dtype) - def expand(self, batch_shape): + def expand(self, batch_shape, _instance=None): try: return super().expand(batch_shape) except NotImplementedError: diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 1e365a990c8290..336b36424f31ce 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: # Import the following modules during type checking to enable code intelligence features, # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch._ops import OpOverload from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint @@ -49,8 +50,11 @@ "ExportedProgram", "ModuleCallEntry", "ModuleCallSignature", + "core_aten_decompositions", "dims", "export", + "export_for_training", + "export_for_inference", "load", "register_dataclass", "save", @@ -61,7 +65,12 @@ from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection -from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature +from .exported_program import ( + core_aten_decompositions, + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) from .graph_signature import ExportBackwardSignature, ExportGraphSignature from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule @@ -69,6 +78,186 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +def export_for_training( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the all ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. This API is intended for PT2 quantization training use cases + and will soon be the default IR of torch.export.export in the near future. To read further about + the motivation behind this change, please refer to + https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + With this API, and :func:`run_decompositions()`, you should be able to get inference IR with + your custom decomposition behaviour. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export_for_training + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + return _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + +def export_for_inference( + mod: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, + strict: bool = True, + preserve_module_call_signature: Tuple[str, ...] = (), + decomp_table: Optional[Dict["OpOverload", Optional[Callable]]] = None, +) -> ExportedProgram: + """ + :func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the ATen operator set + (as well as any user-specified custom operators) which is customizable via decomp_table, + (2) has eliminated all Python control flow and data structures (with certain exceptions), + and (3) records the set of shape constraints needed to show that this normalization and control-flow + elimination is sound for future inputs. This API is for convenience use as it combines :func:`export_for_training` and + :func:`run_decompositions`. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + decomp_table: See :func:`run_decompositions` for more details. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + + ep_for_training = export_for_training( + mod, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + return ep_for_training.run_decompositions(decomp_table=decomp_table) + + def export( mod: torch.nn.Module, args: Tuple[Any, ...], @@ -79,8 +268,7 @@ def export( preserve_module_call_signature: Tuple[str, ...] = (), ) -> ExportedProgram: """ - :func:`export` takes an arbitrary Python callable (an nn.Module, a function or - a method) along with example inputs, and produces a traced graph representing + :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized. The traced graph (1) produces normalized operators in the functional ATen operator set @@ -169,6 +357,12 @@ def export( raise ValueError( f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) return _export( mod, args, diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index 930915f96f9bd1..683e89c3d14910 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -5,64 +5,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import operator -from typing import List import torch from torch._higher_order_ops.auto_functionalize import ( auto_functionalized, - get_mutable_arg_names, + auto_functionalized_v2, ) +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized from torch.export import ExportedProgram -def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes): - # Update every use of the HOP - for node in reversed(auto_functionalize_nodes): - func = node.args[0] - original_kwargs = node.kwargs - assert isinstance(func, torch._ops.OpOverload) - - with ep.graph.inserting_before(node): - # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine? - new_node = ep.graph.call_function(func, kwargs=node.kwargs) - for k, v in node.meta.items(): - new_node.meta[k] = v - - # Replace auto_functionalize(func, args) with just func(args) - node.replace_all_uses_with(new_node) - - mutable_args_names = get_mutable_arg_names(new_node.target) - - # update the users of the auto_func node (the getitem nodes) - for user in list(new_node.users.keys()): - assert user.target == operator.getitem - # getitem corresponding to a mutated input, just replace all uses with the original input - if user.args[1] >= len(func._schema.returns): - assert user.args[1] <= len(func._schema.returns) + len( - mutable_args_names - ) - - # If the result of getitem was used in an output node, update the output spec with the correct name - adjusted_index = user.args[1] - len(func._schema.returns) - original_arg = original_kwargs[mutable_args_names[adjusted_index]] - - # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order - # of the getitem calls following the HOP. - user.replace_all_uses_with(original_arg) - - if len(func._schema.returns) == 1: - # If the function has 1 return then it will just directly return the - # result -- we don't need a getitem. So we can replace all the - # getitem(auto_functionalized, 0) with just the note itself. - for user in list(new_node.users.keys()): - if user.args[1] == 0: - user.replace_all_uses_with(new_node) - - new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)] - ep.graph.erase_node(node) - - ep.graph.eliminate_dead_code() +def remove_self_clone(graph: torch.fx.Graph): + for node in graph.nodes: + if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) def unsafe_remove_auto_functionalized_pass( @@ -73,15 +30,23 @@ def unsafe_remove_auto_functionalized_pass( and modifies the calling EP inplace to have the original mutator op. This pass doesn't perform safety checks to make sure that this inplace mutation is safe. """ - auto_functionalize_nodes: List[torch.fx.Node] = [] - for module in ep.graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - for node in ep.graph.nodes: - if node.op == "call_function" and node.target is auto_functionalized: - auto_functionalize_nodes.append(node) with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes) + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if ( + node.op == "call_function" and node.target is auto_functionalized + ) or ( + node.op == "call_function" and node.target is auto_functionalized_v2 + ): + func = node.args[0] + assert isinstance(func, torch._ops.OpOverload) + # re-inplace everything + node.meta["only_clone_these_tensors"] = [] + decompose_auto_functionalized(ep.graph) + remove_self_clone(ep.graph) + ep.graph.eliminate_dead_code() return ep diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 66c34d1047a89d..abd2b5405fb937 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -38,7 +38,13 @@ lift_constants_pass, rewrite_script_object_meta, ) -from torch._export.utils import placeholder_naming_pass, placeholder_prefixes +from torch._export.utils import ( + _collect_param_buffer_metadata, + _get_shape_env_from_gm, + _populate_param_buffer_metadata_to_new_gm, + placeholder_naming_pass, + placeholder_prefixes, +) from torch._export.verifier import SpecViolationError from torch._export.wrappers import _wrap_submodules from torch._functorch._aot_autograd.input_output_analysis import ( @@ -312,18 +318,12 @@ def _get_param_buffer_mapping( of a traced module to what the original module contains. """ - param_lookup: Dict[int, List[str]] = {} - buffer_lookup: Dict[int, List[str]] = {} + param_lookup: Dict[int, str] = {} + buffer_lookup: Dict[int, str] = {} for name, param in original_module.named_parameters(remove_duplicate=False): - param_lookup.setdefault(id(param), []).append(name) + param_lookup[id(param)] = name for name, buffer in original_module.named_buffers(remove_duplicate=False): - buffer_lookup.setdefault(id(buffer), []).append(name) - - # reverse lists so FQN assignment is FIFO wrt model structure - for name, fqns in param_lookup.items(): - param_lookup[name] = fqns[::-1] - for name, fqns in buffer_lookup.items(): - buffer_lookup[name] = fqns[::-1] + buffer_lookup[id(buffer)] = name param_buffer_table: Dict[str, str] = {} for dynamo_name, dynamo_param in traced_module.named_parameters( @@ -331,14 +331,14 @@ def _get_param_buffer_mapping( ): assert dynamo_name not in param_buffer_table if id(dynamo_param) in param_lookup: - param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop() + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] for dynamo_name, dynamo_buffer in traced_module.named_buffers( remove_duplicate=False ): assert dynamo_name not in param_buffer_table if id(dynamo_buffer) in buffer_lookup: - param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop() + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] return param_buffer_table @@ -371,7 +371,11 @@ def _preserve_requires_grad_pass( assert spec.target is not None constant = constants[spec.target] if isinstance(constant, torch.Tensor): - node.meta["val"].requires_grad = constant.requires_grad + # If the tensor is not leaf, it should already have a correct requires grad field + if node.meta["val"].is_leaf: + node.meta["val"].requires_grad = constant.requires_grad + else: + assert node.meta["val"].requires_grad == constant.requires_grad elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): continue else: @@ -543,7 +547,7 @@ def _export_to_torch_ir( kwargs = kwargs or {} combined_args = _combine_args(f, args, kwargs) _check_dynamic_shapes(combined_args, dynamic_shapes) - _dynamic_shapes = _transform_shapes_for_default_dynamic( + transformed_dynamic_shapes = _transform_shapes_for_default_dynamic( combined_args, dynamic_shapes ) @@ -555,7 +559,7 @@ def _export_to_torch_ir( ), _ignore_backend_decomps(): gm_torch_level, _ = torch._dynamo.export( f, - dynamic_shapes=_dynamic_shapes, # type: ignore[arg-type] + dynamic_shapes=transformed_dynamic_shapes, # type: ignore[arg-type] tracing_mode="symbolic", disable_constraint_solver=disable_constraint_solver, # currently the following 2 flags are tied together for export purposes, @@ -595,6 +599,7 @@ def _export_to_aten_ir( *, transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. pre_dispatch=False, + decomp_table=None, _check_autograd_state=True, _is_torch_jit_trace=False, ) -> ATenExportArtifact: @@ -634,6 +639,7 @@ def _compiling_state_context(): fake_args, trace_joint=False, pre_dispatch=pre_dispatch, + decompositions=decomp_table, kwargs=fake_kwargs, ) @@ -668,9 +674,6 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. # Overwrite output specs afterwards. flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) - fake_mode = torch._export.utils._detect_fake_mode_from_gm(gm) - assert fake_mode is not None, "Cannot detect fake mode from graph" - if not torch._dynamo.config.do_not_emit_runtime_asserts: stack_trace = ( 'File "torch/fx/passes/runtime_assert.py", line 24, ' @@ -679,10 +682,11 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): with _set_node_metadata_hook( gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) ): - if fake_mode: + shape_env = _get_shape_env_from_gm(gm) + if shape_env: insert_deferred_runtime_asserts( gm, - fake_mode.shape_env, # type: ignore[arg-type] + shape_env, f"exported program: {first_call_function_nn_module_stack(gm.graph)}", export=True, ) @@ -1281,43 +1285,6 @@ def _strict_export_lower_to_aten_ir( attr, static_shapes=True ) - # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) - # from the param nodes as they are treated as fresh inputs - # Therefore, we manually extract them before calling into aot_export - params_buffers_to_node_meta = {} - for node in gm_torch_level.graph.nodes: - target = node.target - meta = node.meta - if node.op == "call_module": - submodule = getattr(gm_torch_level, target) - if isinstance(submodule, torch.nn.Module): - for name, _ in submodule.named_parameters( - recurse=True, remove_duplicate=False - ): - params_buffers_to_node_meta[target + "." + name] = meta - - for name, _ in submodule.named_buffers( - recurse=True, remove_duplicate=False - ): - params_buffers_to_node_meta[target + "." + name] = meta - - if node.op == "get_attr": - submodule = getattr(gm_torch_level, target) - if not isinstance(submodule, torch.fx.GraphModule): - params_buffers_to_node_meta[target] = meta - - # If the call_function uses param as input, we also need to update params' meta - # with this call_function node's meta. - # This is basically the same flow as torch.fx.traceback.preserve_meta() - if node.op == "call_function" and not isinstance( - node.target, torch._ops.HigherOrderOperator - ): - for arg in node._input_nodes: - if arg.op == "get_attr": - for entry in torch.fx.proxy._COPY_META_FIELDS: - if entry in meta: - params_buffers_to_node_meta[arg.target][entry] = meta[entry] - # Fix the graph output signature to be tuple if scalar out_spec = orig_out_spec = gm_torch_level._out_spec @@ -1342,6 +1309,13 @@ def _strict_export_lower_to_aten_ir( _normalize_nn_module_stack(gm_torch_level, type(mod)) + params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) + # from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + constant_attrs = _gather_constant_attrs(mod) param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) @@ -1370,26 +1344,9 @@ def _strict_export_lower_to_aten_ir( export_graph_signature = aten_export_artifact.sig constants = aten_export_artifact.constants - # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes - for metadata in params_buffers_to_node_meta.values(): - metadata.pop("nn_module_stack", None) - metadata.pop("stack_trace", None) - - # After aot_export, set the param/buffer metadata back into placeholders - # Technically, users can still construct this data from param names - # without relying on this metadata - for node in gm.graph.nodes: - if node.op == "placeholder": - if node.target in export_graph_signature.inputs_to_parameters: - param_name = export_graph_signature.inputs_to_parameters[node.target] - if param_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[param_name].items(): - node.meta[k] = v - if node.target in export_graph_signature.inputs_to_buffers: - buffer_name = export_graph_signature.inputs_to_buffers[node.target] - if buffer_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[buffer_name].items(): - node.meta[k] = v + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, export_graph_signature + ) # Do some cleanups on the graph module to restore the state dict to the # expected form. Each of these steps should probably get fixed upstream. @@ -1682,6 +1639,7 @@ def forward(self, *args, **kwargs): fake_kwargs, equalities_inputs, original_signature, + transformed_dynamic_shapes, ) = make_fake_inputs( mod, args, @@ -1697,7 +1655,7 @@ def _produce_guards_callback(gm): return produce_guards_and_solve_constraints( fake_mode=fake_mode, gm=gm, - dynamic_shapes=dynamic_shapes, + dynamic_shapes=transformed_dynamic_shapes, equalities_inputs=equalities_inputs, original_signature=original_signature, _is_torch_jit_trace=_is_torch_jit_trace, @@ -1742,6 +1700,7 @@ def _produce_guards_callback(gm): ) assert out_spec is not None + return ExportArtifact( aten=aten_export_artifact, out_spec=out_spec, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 77b2c1e34a9772..a84467767ac1eb 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import dataclasses import inspect +import logging import sys from collections import defaultdict from enum import auto, Enum @@ -30,20 +31,24 @@ __all__ = [ "Constraint", - "DIM", "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", ] -class DIM(Enum): +log = logging.getLogger(__name__) + + +class _DimHint(Enum): """ - Enum for automatic/static dynamic shapes. + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). """ - STATIC = auto() AUTO = auto() + STATIC = auto() class _Dim(type): @@ -214,6 +219,7 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): Returns: A type that can be used in dynamic shape specifications for tensors. """ + from torch.utils._sympy.numbers import int_oo _min = 0 if min is None else min @@ -227,6 +233,10 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): return dim +Dim.AUTO = _DimHint.AUTO # type: ignore[attr-defined] +Dim.STATIC = _DimHint.STATIC # type: ignore[attr-defined] + + def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): """ Util to create multiple :func:`Dim` types. @@ -631,6 +641,15 @@ def find_shape(path, t): return dynamic_shapes +def _warn_on_None_dynamic_shape_dimension(): + msg = ( + "Using None as a dynamic shape dimension is deprecated. " + "Please use Dim.STATIC instead" + ) + # TODO(avik): raise an error in the future + log.warning(msg) + + def _check_dynamic_shapes( combined_args: Dict[str, Any], dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], @@ -668,24 +687,28 @@ def check_symbols(path, tensor, shape): for i, dim in shape.items(): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif isinstance(shape, (tuple, list)): for i, dim in enumerate(shape): if isinstance(dim, _Dim): check_same_bounds(dim) - elif not (isinstance(dim, (int, DIM)) or dim is None): + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): raise UserError( UserErrorType.INVALID_INPUT, f"Unexpected dimension #{i} in input tensor shape {shape} " f"specified at `dynamic_shapes{keystr(path)}` " - f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", + f"(expected None, an int, a Dim, Dim.AUTO, or Dim.STATIC, but got {dim} instead)", case_name="dynamic_shapes_validation", ) elif shape is not None: @@ -693,7 +716,7 @@ def check_symbols(path, tensor, shape): UserErrorType.INVALID_INPUT, f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," - f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)", + f" where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)", case_name="dynamic_shapes_validation", ) @@ -740,18 +763,18 @@ def check_shape(path, t, dynamic_shape): _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") - # raise user warning if both DIM.AUTO & Dims are specified in dynamic_shapes + # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( - s == DIM.AUTO for s in flatter_dynamic_shapes + s == _DimHint.AUTO for s in flatter_dynamic_shapes ): raise UserError( UserErrorType.INVALID_INPUT, - "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " + "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " - "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. " - "We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " + "expect all equal or related dimensions to be specified, and does not yet compose well with `Dim.AUTO`. " + "We suggest using `Dim.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " "if you want to assert on the exact specification of your program's dynamic shapes behavior.", case_name="dynamic_shapes_validation", @@ -773,8 +796,8 @@ def _transform_shapes_for_default_dynamic( for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: - - DIM.AUTO: dynamic, allocated a symbol - - None/unspecified/DIM.STATIC: static + - Dim.AUTO: dynamic, allocated a symbol + - None/unspecified/Dim.STATIC: static - Dim/DerivedDims: also a strict assertion To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes @@ -784,8 +807,8 @@ def _transform_shapes_for_default_dynamic( An example conversion might look like, for a 3-d input tensor: input spec: { - 0: DIM.AUTO, - 1: None, # or DIM.STATIC + 0: Dim.AUTO, + 1: None, # or Dim.STATIC 2: Dim("dx"), } output spec: { @@ -823,50 +846,50 @@ def _tree_map_helper(tree, val): combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] def transform_shapes(path, tensor, shape): - def _marked_dynamic(tensor, i): - # TODO(pianpwk): deprecate mark_dynamic() usage for export - return i in getattr(tensor, "_dynamo_dynamic_indices", set()) - out: Union[None, List[Any], Dict[int, Any]] = None if isinstance(shape, dict): out = {} for i, val in enumerate(tensor.shape): - dim = shape.get(i, None) - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + dim = shape.get(i, _DimHint.STATIC) + if dim == _DimHint.AUTO: # don't have to specify anything if dynamic # None also works, since assume_static_by_default=False - continue + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing elif isinstance(dim, _Dim): out[i] = dim elif isinstance(dim, int): # important that this is dim and not val, # so we can raise error if user-specified dim != val out[i] = dim + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out[i] = val else: # make explicitly static - assert dim is None or dim == DIM.STATIC + assert dim == _DimHint.STATIC out[i] = val elif isinstance(shape, (tuple, list)): out = [] for i, val in enumerate(tensor.shape): dim = shape[i] - if _marked_dynamic(tensor, i) or dim == DIM.AUTO: + if dim == _DimHint.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, i) # avoid duck sizing out.append(None) elif isinstance(dim, _Dim): out.append(dim) elif isinstance(dim, int): out.append(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + out.append(val) else: - assert dim is None or dim == DIM.STATIC + assert dim == _DimHint.STATIC out.append(val) out = type(shape)(out) # type: ignore[assignment] else: assert shape is None if isinstance(tensor, torch.Tensor): - out = [] - for i, val in enumerate(tensor.shape): - out.append(None if _marked_dynamic(tensor, i) else val) - out = out or None + out = list(tensor.shape) or None else: out = None return out @@ -1036,8 +1059,12 @@ def _get_dim_name_mapping( dynamic_shapes, is_leaf=lambda x: isinstance(x, _Dim), )[0]: - if isinstance(dim, (int, DIM)) or dim is None: + if dim is None: + # NOTE: this must denote a non-Tensor or automatic at this point. + continue + if isinstance(dim, int): continue + assert isinstance(dim, _Dim) # dim hints should have boiled away name_to_dim[dim.__name__] = dim if isinstance(dim, _DerivedDim): name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] @@ -1118,9 +1145,11 @@ def refine_dynamic_shapes_from_suggested_fixes( name, expr = fix.split(" = ") expr = sympy.sympify(expr) if isinstance(expr, sympy.Number): - shape_fixes[name] = int(expr) # static, integer + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] else: - shape_fixes[name] = expr # relation or derived dim + # relation or derived dim + shape_fixes[name] = expr name_to_dim = _get_dim_name_mapping(dynamic_shapes) diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index c501a5b9953583..d830ed6dc65837 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -57,8 +57,8 @@ def _export_forward_backward( ep = _decompose_exported_program( ep, - decomp_table=core_aten_decompositions(), - _preserve_ops=(), # type: ignore[arg-type] + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), joint_loss_index=joint_loss_index, ) gm, new_graph_signature = _copy_graph_module_and_signature(ep) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index d90ff9580f2dec..2788e804257ac1 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -5,7 +5,6 @@ import dataclasses import functools import operator -import re import types import warnings from collections import namedtuple @@ -26,8 +25,10 @@ from torch._higher_order_ops.utils import autograd_not_implemented from torch._library.fake_class_registry import FakeScriptObject +from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts if TYPE_CHECKING: @@ -41,15 +42,22 @@ import torch import torch.utils._pytree as pytree +from torch._export.utils import ( + _collect_and_set_constant_attrs, + _collect_param_buffer_metadata, + _detect_fake_mode_from_gm, + _name_hoo_subgraph_placeholders, + _overwrite_signature_for_non_persistent_buffers, + _populate_param_buffer_metadata_to_new_gm, + _rename_without_collisions, +) from torch._export.verifier import Verifier +from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import unset_fake_temporarily -from torch._subclasses.functional_tensor import FunctionalTensor from torch.export._tree_utils import is_equivalent, reorder_kwargs from torch.fx._compatibility import compatibility -from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager -from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from .graph_signature import ( # noqa: F401 ArgumentSpec, @@ -70,6 +78,7 @@ "ExportedProgram", "ModuleCallEntry", "ModuleCallSignature", + "core_aten_decompositions", ] @@ -83,6 +92,14 @@ class ModuleCallSignature: in_spec: pytree.TreeSpec out_spec: pytree.TreeSpec + def replace_all_uses_with(self, original_node, new_node): + for i in self.inputs: + if i.name == original_node.name: + i.name = new_node.name + for o in self.outputs: + if o.name == original_node.name: + o.name = new_node.name + @dataclasses.dataclass class ModuleCallEntry: @@ -157,7 +174,7 @@ def _register_cia_to_meta(*args, **kwargs): @contextmanager -def _override_composite_implicit_decomp(ops_to_preserve, decomp_table): +def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): # This function overrides CompositeImplicitAutograd decomp for # functional composite ops that user specified. Ideally we want to not-decompose # ALL composite ops but today's C++ functinalization relies on @@ -165,49 +182,17 @@ def _override_composite_implicit_decomp(ops_to_preserve, decomp_table): # Hence we can only do it for functional ops. One caveat is that # there are some composite ops that lie about their schema (claimed to be # functional but not really aka dropout), for these cases, we just decompose. - saved_tables = {} - patched_ops = set() - removed_decomps = {} - for op_overload in ops_to_preserve: - # Our strategy for deciding if we can preserve CIA is following: - # 1. The op should be known statically that it is functional - # 2. If it is maybe aliasing, we decompose because we must know if an op - # is mutating or aliasing. - # TODO (tmanlaibaatar) make this utility function and share it with functional_tensor - # decomp part. (https://github.com/pytorch/pytorch/issues/129431) - def assert_valid_to_preserve(op_overload): - if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: - raise RuntimeError( - f"We can't detect {op_overload} as a functional op statically, so we can't preserve it" - ) - if op_overload in FunctionalTensor.metadata_fns: - raise RuntimeError( - f"{op_overload} is a metadata query function, " - "it will be preserved implicitly in our tracing system. " - "Please file an issue on github if you see otherwise" - ) - - alias_info = len( - [i for i in op_overload._schema.arguments if i.alias_info is not None] - ) - - is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable - - if is_mutating_or_aliasing: - raise RuntimeError( - f"{op_overload} is a mutating/aliasing op, we can't preserve it as is" - ) - - if not torch._C._dispatch_has_kernel(op_overload.name()): - raise RuntimeError( - f"{op_overload} is a TorchScript op, we can't preserve it as is" - ) - return True - - # If we didn't error, it means we can go ahead - assert_valid_to_preserve(op_overload) + # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing + # and their usual decompositions need to be shadowed rather than overridden. + # Thus we will avoid asserting that they are valid to preserve, and will not + # replace their CompositeImplicitAutograd kernels with NotImplemented. + # The only current users of this mode are variants of aten::to that we will + # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. + saved_tables = {} + patched_ops = set() + for op_overload, decomp_callable in cia_ops_to_callable.items(): saved_tables[op_overload] = op_overload.py_kernels.copy() patched_ops.add(op_overload) @@ -220,10 +205,10 @@ def assert_valid_to_preserve(op_overload): if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] - def _(*args, **kwargs): - return NotImplemented - - op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(_) + if safe: + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) # For fake tensor prop, we do want to register meta kernel directly if torch._C.DispatchKey.Meta not in op_overload.py_kernels: @@ -231,10 +216,6 @@ def _(*args, **kwargs): functools.partial(_register_cia_to_meta, kernel=op_overload) ) - if op_overload in decomp_table: - removed_decomps[op_overload] = decomp_table[op_overload] - del decomp_table[op_overload] - try: yield finally: @@ -243,91 +224,91 @@ def _(*args, **kwargs): op.py_kernels.update(saved_tables[op]) op._dispatch_cache.clear() - for op, decomp in removed_decomps.items(): - decomp_table[op] = decomp +def _special_op_to_preserve_cia(*args, **kwargs): + "This is an special marker that tells our infra that we shouldn't decompose this op" + return NotImplemented -def _rename_without_collisions( - name_map: Dict[str, str], - orig_name: str, - name: str, - is_placeholder: bool = False, -): - """ - Renames nodes to avoid name collisions, with suffixing. - name_map: map from original name to new name - orig_name: mapping key - name: candidate name (potentially suffixed, e.g. mul_2) - is_placeholder: if the node is a placeholder, avoid detecting suffix - """ - if name in name_map.values(): - # non-placeholder nodes may be suffixed with the count - # instead of adding another suffix, we will try to increment it - match = re.match(r"(.*)_(\d+)", name) - if match and not is_placeholder: - name, n = match.group(1), int(match.group(2)) - else: - n = 0 - while (dup_name := f"{name}_{n + 1}") in name_map.values(): - n += 1 - name_map[orig_name] = dup_name - else: - name_map[orig_name] = name - return name_map[orig_name] - - -def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: + +@contextmanager +def _override_decomp_aten_to_variants(): + # Preserve variants of aten::to understanding that they are mutating/aliasing + # and their CompositeImplicitAutograd kernels will not become NotImplemented. + # We will later replace them with aten._to_copy when functionalizing. + with _override_composite_implicit_decomp( + { + torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia, + torch.ops.aten.to.dtype: _special_op_to_preserve_cia, + }, + safe=False, + ): + yield + + +def _split_decomp_table_to_cia_and_python_decomp( + decomp_table: Dict[torch._ops.OperatorBase, Callable] +) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]: + from torch._decomp import _collect_all_valid_cia_ops, _get_decomp_for_cia + + all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) + cia_ops_to_callable = {} + + for op in list(decomp_table.keys()): + # TODO we are silently allowing non-safe(non-functional) ops through a crack + # due to core aten decomp table having non-functional entries. Once we have + # a tigher check around core aten decomp, we should warn users about them. + # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759) + + # if it is a valid CIA op we can mess with in export, we check if it is: + # 1. Has been marked as to be decomposed. Example: + # decomp_table = decomp_table_to_core_aten() + # del decomp_table[aten.linear] + # In this case, user says decompose everything except for aten.linear + # 2. Has been marked with custom decomp behavour. Example: + # decomp_table = {aten.linear: some_op} + # For (1), we want to remove all the CIA ops that weren't handled by user as + # it suggests they are safe to decompose, so we should remove from preservable_list. + # for (2), we just plumb the custom decomp to AOTDIspatcher. + # In both cases, we want to remove this CIA op from the decomp_table as it is special + # handled. + if op in all_preservable_cia_ops: + # TODO this is annpying case where aten.item has + # prim decomposition which later calls into aten.item + # and recurses infinitely. (https://github.com/pytorch/pytorch/issues/136050) + if op == torch.ops.aten.item.default: + cia_ops_to_callable[op] = _get_decomp_for_cia(op) + else: + cia_ops_to_callable[op] = decomp_table[op] + all_preservable_cia_ops.remove(op) + del decomp_table[op] + + # If we reached here, it means user intentionally deleted these CIA ops from + # decomp table. + for k in all_preservable_cia_ops: + cia_ops_to_callable[k] = _special_op_to_preserve_cia + + return cia_ops_to_callable, decomp_table + + +def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: """ - Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, - and handle collisions with non-placeholders by count suffixing. - Different HOO subgraph types have different input schemas, so we first enumerate them - and gather the top-level named placeholder nodes. + This is the default decomposition table which contains decomposition of + all ATEN operators to core aten opset. Use this API together with + :func:`run_decompositions()` """ - # gather all HOO subgraphs and their top-level named placeholder nodes - subgraph_ph_tuples: List[Tuple[torch.fx.GraphModule, List[torch.fx.Node]]] = [] - for node in gm.graph.nodes: - if node.op == "call_function" and isinstance( - node.target, torch._ops.HigherOrderOperator - ): - # HOO subgraphs have varying input schemas, so we enumerate them there - if node.target._name == "cond": - _, true_graph, false_graph, cond_args = node._args - subgraph_ph_tuples.append((getattr(gm, true_graph.target), cond_args)) - subgraph_ph_tuples.append((getattr(gm, false_graph.target), cond_args)) - elif node.target._name == "wrap_with_set_grad_enabled": - subgraph, phs = node._args[1], node._args[2:] - subgraph_ph_tuples.append((getattr(gm, subgraph.target), phs)) - elif node.target._name == "map_impl": - body_graph, array, args = node._args - subgraph_ph_tuples.append( - (getattr(gm, body_graph.target), array + args) - ) - - # propagate names - for subgraph, hoo_phs in subgraph_ph_tuples: - name_map: Dict[str, str] = {} - for i, node in enumerate(subgraph.graph.nodes): - if i < len(hoo_phs): # placeholder, retain name - name_map[node.name] = hoo_phs[i].name - node.name = node.target = hoo_phs[i].name - else: # non-placeholder, check for collisions - node.name = _rename_without_collisions(name_map, node.name, node.name) + from torch._decomp import core_aten_decompositions - # recurse and recompile - _name_hoo_subgraph_placeholders(subgraph) - subgraph.recompile() + return core_aten_decompositions() def _decompose_and_get_gm_with_new_signature_constants( ep, *, - decomp_table: Dict[torch._ops.OperatorBase, Callable], - _preserve_ops: Tuple[torch._ops.OpOverload], + cia_to_decomp: Dict[torch._ops.OperatorBase, Callable], + python_decomp_table: Dict[torch._ops.OperatorBase, Callable], joint_loss_index: Optional[int], ): - from torch._export.passes.lift_constants_pass import ConstantAttrMap from torch._functorch.aot_autograd import aot_export_module - from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensorMode from torch.export._trace import ( _export_to_aten_ir, @@ -339,12 +320,18 @@ def _decompose_and_get_gm_with_new_signature_constants( ) from torch.fx.experimental.symbolic_shapes import ShapeEnv + # TODO Merge this path with inference IR decomp, but it will require some additional work + # so I will leave it for now. T200307782 if ep.verifier.dialect == "TRAINING": mod = ep.module() - fake_args = [ - node.meta["val"] for node in mod.graph.nodes if node.op == "placeholder" - ] - fake_mode = torch._export.utils._detect_fake_mode_from_gm(mod) + + fake_args = [] + for node in mod.graph.nodes: + if node.op == "placeholder": + fake_args.append(node.meta["val"]) + + fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) + fake_mode = _detect_fake_mode_from_gm(mod) if fake_mode is None: fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) @@ -370,51 +357,20 @@ def _decompose_and_get_gm_with_new_signature_constants( # the exported module will store constants & non-persistent buffers such that # retracing treats them as persistent buffers, so we inform the constants lifting pass # and overwrite the new graph signature using the previous program. - constant_attrs = ConstantAttrMap() - non_persistent_buffers = { - spec.target - for spec in ep.graph_signature.input_specs - if spec.kind == InputKind.BUFFER and not spec.persistent - } - for name, value in ep.constants.items(): - if name in non_persistent_buffers: - continue - # recursive getattr - _mod = mod - *atoms, attr = name.split(".") - for atom in atoms: - _mod = getattr(_mod, atom) - # remove as buffer, reassign as constant/non-persistent buffer - _mod._buffers.pop(attr, None) - setattr(_mod, attr, value) - constant_attrs.add(value, name) + constant_attrs = _collect_and_set_constant_attrs( + ep.graph_signature, ep.constants, mod + ) # get params & buffers after excluding constants fake_params_buffers = _fakify_params_buffers(fake_mode, mod) - params_buffers_to_node_meta = {} - for node in mod.graph.nodes: - target = node.target - meta = node.meta - if node.op == "get_attr": - params_buffers_to_node_meta[target] = meta - - # If the call_function uses param as input, we also need to update params' meta - # with this call_function node's meta. - # This is basically the same flow as torch.fx.traceback.preserve_meta() - if node.op == "call_function" and not isinstance( - node.target, torch._ops.HigherOrderOperator - ): - for arg in node._input_nodes: - if arg.op == "get_attr": - for entry in torch.fx.proxy._COPY_META_FIELDS: - if entry in meta: - params_buffers_to_node_meta[arg.target][entry] = meta[ - entry - ] - - with fake_mode: - fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) + params_buffers_to_node_meta = _collect_param_buffer_metadata(mod) + + with _ignore_backend_decomps(), ( + fake_mode + ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( + cia_to_decomp, + ): aten_export_artifact = _export_to_aten_ir( mod, # this requires empty kwargs, but not in pytree.flattened format @@ -425,49 +381,27 @@ def _decompose_and_get_gm_with_new_signature_constants( {}, fake_params_buffers, constant_attrs, + decomp_table=python_decomp_table, _check_autograd_state=False, ) gm = aten_export_artifact.gm new_graph_signature = aten_export_artifact.sig - for node in gm.graph.nodes: - # nn_module_stack - if node.op not in ["placeholder", "output"]: - for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): - if isinstance(mod_cls, type): - node.meta["nn_module_stack"][key] = ( - fqn, - mod_cls.__module__ + "." + mod_cls.__qualname__, - ) - - # Don't copy over nn_module_stack, stack_trace metadata for params/buffers nodes - for metadata in params_buffers_to_node_meta.values(): - metadata.pop("nn_module_stack", None) - metadata.pop("stack_trace", None) - - for node in gm.graph.nodes: - if node.op == "placeholder": - if node.target in new_graph_signature.inputs_to_parameters: - param_name = new_graph_signature.inputs_to_parameters[node.target] - if param_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[param_name].items(): - node.meta[k] = v - if node.target in new_graph_signature.inputs_to_buffers: - buffer_name = new_graph_signature.inputs_to_buffers[node.target] - if buffer_name in params_buffers_to_node_meta: - for k, v in params_buffers_to_node_meta[buffer_name].items(): - node.meta[k] = v + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, new_graph_signature + ) # overwrite signature for non-persistent buffers - for spec in new_graph_signature.input_specs: - if spec.kind == InputKind.BUFFER and spec.target in non_persistent_buffers: - spec.persistent = False + new_graph_signature = _overwrite_signature_for_non_persistent_buffers( + ep.graph_signature, new_graph_signature + ) _verify_nn_module_stack(gm) _verify_stack_trace(gm) _verify_placeholder_names(gm, new_graph_signature) - return gm, new_graph_signature + + return _remove_unneccessary_copy_op_pass(gm, new_graph_signature) old_placeholders = [ node for node in ep.graph_module.graph.nodes if node.op == "placeholder" @@ -482,13 +416,12 @@ def _decompose_and_get_gm_with_new_signature_constants( fake_mode = detect_fake_mode(fake_args) fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( - _preserve_ops, - decomp_table, + cia_to_decomp ): gm, graph_signature = aot_export_module( ep.graph_module, fake_args, - decompositions=decomp_table, + decompositions=python_decomp_table, trace_joint=True if joint_loss_index is not None else False, output_loss_index=joint_loss_index if joint_loss_index is not None @@ -631,7 +564,34 @@ def update_arg(old_arg, new_ph): return gm, new_graph_signature -def _common_getitem_elimination_pass(gm: torch.fx.GraphModule, graph_signature): +def _remove_unneccessary_copy_op_pass( + gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature +) -> Tuple[torch.fx.GraphModule, ExportGraphSignature]: + """ + Removes redundant copy_ node that was introduced due to mutated buffer. + """ + with gm._set_replace_hook(new_graph_signature.get_replace_hook()): + for node in gm.graph.nodes: + if node.op == "output": + args, _ = pytree.tree_flatten(node.args) + for out in args: + if ( + isinstance(out, torch.fx.Node) + and out.name in new_graph_signature.buffers_to_mutate + ): + if ( + out.op == "call_function" + and out.target == torch.ops.aten.copy.default + ): + out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] + gm.graph.erase_node(out) + gm.recompile() + return gm, new_graph_signature + + +def _common_getitem_elimination_pass( + gm: torch.fx.GraphModule, graph_signature, module_call_graph +): with gm._set_replace_hook(graph_signature.get_replace_hook()): for module in gm.modules(): if not isinstance(module, torch.fx.GraphModule): @@ -639,12 +599,17 @@ def _common_getitem_elimination_pass(gm: torch.fx.GraphModule, graph_signature): node_id: Dict[torch.fx.Node, str] = {} getitems: Dict[str, torch.fx.Node] = {} - for node in module.graph.nodes: + for node in list(module.graph.nodes): if node.op == "call_function" and node.target == operator.getitem: source, idx = node.args new_id = f"{node_id[source]}.{idx}" if new_id in getitems: node.replace_all_uses_with(getitems[new_id]) + for entry in module_call_graph: + if entry.signature is not None: + entry.signature.replace_all_uses_with( + node, getitems[new_id] + ) module.graph.erase_node(node) else: getitems[new_id] = node @@ -656,14 +621,14 @@ def _common_getitem_elimination_pass(gm: torch.fx.GraphModule, graph_signature): def _decompose_exported_program( ep, *, - decomp_table: Dict[torch._ops.OperatorBase, Callable], - _preserve_ops: Tuple[torch._ops.OpOverload], + cia_to_decomp: Dict[torch._ops.OperatorBase, Callable], + python_decomp_table: Dict[torch._ops.OperatorBase, Callable], joint_loss_index: Optional[int], ): gm, new_graph_signature = _decompose_and_get_gm_with_new_signature_constants( ep, - decomp_table=decomp_table, - _preserve_ops=_preserve_ops, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, joint_loss_index=joint_loss_index, ) @@ -727,7 +692,9 @@ def __init__( if isinstance(root, torch.fx.GraphModule): self._graph_module.meta.update(root.meta) - _common_getitem_elimination_pass(self._graph_module, graph_signature) + _common_getitem_elimination_pass( + self._graph_module, graph_signature, module_call_graph + ) self._graph_signature: ExportGraphSignature = graph_signature self._state_dict: Dict[str, Any] = state_dict self._range_constraints: Dict[sympy.Symbol, ValueRanges] = range_constraints @@ -1038,16 +1005,82 @@ def run_decompositions( `Core ATen Operator Set `_. For now, we do not decompose joint graphs. + + Args: + decomp_table: + An optional argument that specifies decomp behaviour for Aten ops + (1) If None, we decompose to core aten decompositions + (2) If empty, we don't decompose any operator + + + Some examples: + + If you don't want to decompose anything + + .. code-block:: python + + ep = torch.export.export(model, ...) + ep = ep.run_decompositions(decomp_table={}) + + If you want to get a core aten operator set except for certain operator, you can do following: + + .. code-block:: python + + ep = torch.export.export(model, ...) + decomp_table = torch.export.core_aten_decompositions() + decomp_table[your_op] = your_custom_decomp + ep = ep.run_decompositions(decomp_table=decomp_table) """ - from torch._decomp import core_aten_decompositions + from torch._decomp import ( + _decomp_table_to_post_autograd_aten, + core_aten_decompositions, + ) + from torch._inductor import config + + # FIXME delete this option after PTC, Executorch syncing is + # bit annoying so can't get rid of it easily + if _preserve_ops != (): + warnings.warn( + "This API is deprecated and soon will be removed. " + "Please look at the docstring to see how to preserve " + "an operator." + ) + + _decomp_table = ( + core_aten_decompositions() if decomp_table is None else dict(decomp_table) + ) - if decomp_table is None: - decomp_table = core_aten_decompositions() + if config.is_fbcode(): + # This means the decomp_table would only be containing post-autograd ops + # We should manually add CIA decomps + for k, v in _decomp_table_to_post_autograd_aten().items(): + _decomp_table[k] = v + + for op in _preserve_ops: + if op in _decomp_table: + del _decomp_table[op] + + # Note [Seperating decomp_table into CIA decomps and non-CIA decomps] + # At this point, we have a decomp_table that contains decomp behaviour for + # both CIA and post-autograd ops. + # We need to separate the op into two categories: + # 1. CIA op: These are the ops that we want to override + # CompositeImplicitAutograd decomp for. For them, we need to use _override_composite_implicit_decomp + # context manager to plumb it through AOTDispatcher + # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just + # checking if they are statically functional is enough. + # For joint IR case tho, we need to use the old path because we can't register + # custom decomps this way because we can't use context manager as it installs + # autograd_error node. + ( + cia_to_decomp, + python_decomp_table, + ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table) return _decompose_exported_program( self, - decomp_table=decomp_table, - _preserve_ops=_preserve_ops, # type: ignore[arg-type] + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, joint_loss_index=None, ) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 45f7b992827e79..2aa12e43770ee8 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -3,6 +3,7 @@ import copy import operator from collections import defaultdict +from contextlib import contextmanager from copy import deepcopy from enum import Enum from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union @@ -36,6 +37,20 @@ class _AttrKind(Enum): CONSTANT = "constant" +RUN_WITH_INTERPRETER = True + + +@contextmanager +def _disable_interpreter(): + global RUN_WITH_INTERPRETER + old_flag = RUN_WITH_INTERPRETER + RUN_WITH_INTERPRETER = False + try: + yield + finally: + RUN_WITH_INTERPRETER = old_flag + + # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target def _assign_attr( @@ -87,13 +102,18 @@ def __init__( super().__init__() self.graph = graph self.graph.owning_module = self + self._run_with_interpeter = RUN_WITH_INTERPRETER def forward(self, *args, **kwargs): assert self.graph_module is not None, "Didn't finalize this InterpreterModule" - if torch.compiler.is_dynamo_compiling(): + if not is_fx_tracing() and ( + torch.compiler.is_dynamo_compiling() or not self._run_with_interpeter + ): # Dynamo cannot trace through torch.fx.Interpreter, so fall back to # GraphModule codegen in this instance. - return self.graph_module(*args, **kwargs) + # Patch the codegened forward to run with this InterpreterModule, + # so attribute accesses, etc. are on this module instead. + return type(self.graph_module).forward(self, *args, **kwargs) else: if kwargs: # Handle **kwargs. FX only natively supports positional @@ -182,10 +202,12 @@ def __init__( export_graph = deepcopy(export_module.graph) self.graph_signature = deepcopy(export_module.graph_signature) self.graph = torch.fx.Graph() + self.graph.owning_module = self self.module_call_graph = deepcopy(export_module.module_call_graph) self.flat_args_adapter = flat_args_adapter # Flag to indicate whether args have been adapted. self.adapted = False + self._run_with_interpeter = RUN_WITH_INTERPRETER _inplace_buffer_mutations(export_graph, self.graph_signature) _outline_submodules(export_graph, self) @@ -407,6 +429,7 @@ def check_module_inputs(module, scope): assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list( fqn_order.keys() ) + self.graph.lint() def _print_graph(self): for fqn, mod in self.named_modules(): @@ -479,9 +502,12 @@ def forward(self, *args, **kwargs): _check_input_constraints_for_graph( self.input_placeholders, new_flat_args_with_path, self.range_constraints ) - tree_out = torch.fx.Interpreter(self, graph=self.graph).run( - *flat_args, enable_io_processing=False - ) + if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter: + tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args) + else: + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) return pytree.tree_unflatten(tree_out, signature.out_spec) def print_readable( @@ -523,6 +549,8 @@ def unflatten( An instance of :class:`UnflattenedModule`, which has the same module hierarchy as the original eager module pre-export. """ + if module.verifier.dialect == "TRAINING": + raise RuntimeError("Unflattener doesn't support non-functional training IR yet") module = _remove_effect_tokens(module) return UnflattenedModule(module, flat_args_adapter) @@ -838,14 +866,8 @@ def copy_sym_call_function(self, x): # To avoid this we copy these call_function nodes with sym_type results. # This should however only be done for sym_type nodes - call_function nodes on tensors # should not be deduplicated in the first place. - args = tuple( - self.remap_input(_x) if isinstance(_x, torch.fx.Node) else _x - for _x in x.args - ) - kwargs = { - k: self.remap_input(_x) if isinstance(_x, torch.fx.Node) else _x - for k, _x in x.kwargs.items() - } + args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) + kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) node = self.graph.call_function(x.target, args, kwargs) node.meta = copy.copy(x.meta) self.node_map[x] = node @@ -870,7 +892,18 @@ def remap_input(self, x): with self.parent.graph.inserting_before(self.parent_call_module): self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) return self.node_to_placeholder[x] - elif x.op == "call_function": + elif x.op == "call_function" and ( + x.target + in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.item.default, + torch.ops.aten.unbind.int, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.view.default, + torch.ops.aten.diff.default, + ) + or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") + ): # export deduplicates sym_size nodes, and may need to re-copy them # if module call signature needs to be preserved self.copy_sym_call_function(x) @@ -879,7 +912,6 @@ def remap_input(self, x): raise RuntimeError( f"Could not run remap_input() on op type: {x.op} for node {x}" ) - return self.node_to_placeholder[x] def finalize_outputs(self): orig_outputs = [] diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index bdbd4b04429d31..6693863386513c 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import builtins import copy +import contextlib import functools import inspect import math @@ -793,13 +794,13 @@ def forward(*args, **kwargs): return _orig_module_call(mod, *args, **kwargs) _autowrap_check( - patcher, + patcher, # type: ignore[has-type] getattr(getattr(mod, "forward", mod), "__globals__", {}), self._autowrap_function_ids, ) return self.call_module(mod, forward, args, kwargs) - with _Patcher() as patcher: + with _new_patcher() as patcher: # allow duplicate patches to support the case of nested calls patcher.patch_method( torch.nn.Module, @@ -990,25 +991,36 @@ class _PatchedFn(NamedTuple): frame_dict: Any fn_name: str orig_fn: Any + new_fn: Any def revert(self): raise NotImplementedError + def patch(self): + raise NotImplementedError + class _PatchedFnSetItem(_PatchedFn): def revert(self): self.frame_dict[self.fn_name] = self.orig_fn + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn class _PatchedFnDel(_PatchedFn): def revert(self): del self.frame_dict[self.fn_name] + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + class _PatchedFnSetAttr(_PatchedFn): def revert(self): setattr(self.frame_dict, self.fn_name, self.orig_fn) + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) class _Patcher: def __init__(self) -> None: @@ -1028,14 +1040,15 @@ def patch( """ new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] if name not in frame_dict and hasattr(builtins, name): - self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() elif getattr(frame_dict[name], "__fx_already_patched", False): return # already patched, no need to do it again else: self.patches_made.append( - _PatchedFnSetItem(frame_dict, name, frame_dict[name]) + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) ) - frame_dict[name] = new_fn + self.patches_made[-1].patch() def patch_method( self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True @@ -1047,8 +1060,8 @@ def patch_method( orig_fn = getattr(cls, name) if getattr(orig_fn, "__fx_already_patched", False): return # already patched, no need to do it again - self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) - setattr(cls, name, new_fn) + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() def visit_once(self, thing: Any): """Return True on the first call to with thing, otherwise false""" @@ -1058,6 +1071,22 @@ def visit_once(self, thing: Any): self.visited.add(idx) return True + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + def __enter__(self): return self @@ -1071,6 +1100,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.visited.clear() +CURRENT_PATCHER: Optional[_Patcher] = None + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + def _patch_wrapped_functions(patcher: _Patcher): """ Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index ab9ef8c0c7a73c..213672e216e6ce 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from contextlib import contextmanager, ExitStack, nullcontext +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext from dataclasses import dataclass from typing import ( Any, @@ -596,6 +596,19 @@ def track_tensor_tree( constant: Optional[_NestedTensors], tracer: _ProxyTracer, ) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. _set_unbacked_bindings(inner_res, proxy_res) def wrap_with_proxy( @@ -1071,38 +1084,43 @@ def unwrap_proxy(self, e: T) -> object: return e -@contextmanager -def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: - from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode +def _make_temp_remove_mode_context_manager( + mode_ty: Type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode - temp_elements = [] - pre_dispatch_mode = None + temp_elements = [] + removed_mode = None - while _len_torch_function_stack() > 0: - mode = _pop_mode() - if isinstance(mode, PreDispatchTorchFunctionMode): - pre_dispatch_mode = mode - break - else: - temp_elements.append(mode) + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + for mode in reversed(temp_elements): + _push_mode(mode) - try: - yield + try: + yield removed_mode - finally: - if pre_dispatch_mode is not None: - count = len(temp_elements) - while count > 0: - mode = _pop_mode() - count -= 1 + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) - temp_elements.append(pre_dispatch_mode) + for mode in reversed(temp_elements): + _push_mode(mode) - for mode in reversed(temp_elements): - _push_mode(mode) + return context_manager_fn @torch._disable_dynamo @@ -1217,6 +1235,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + # This mode is **only** used for pre_dispatch tracing. # In particular, we need to make sure that autograd/autocast API's # that do not desugar into dispatcher operators stay in the graph. @@ -1245,6 +1268,11 @@ def __torch_function__( return func(*args, **kwargs) +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property @@ -1566,6 +1594,9 @@ def __getattr__(self, name: str) -> AttrProxy: ) return tracer.attr_proxy_map[attr_val] + def get_base(self) -> Module: + return tracer.proxy_modules[self] + @property def _modules(self) -> Dict[str, AttrProxy]: assert "_modules" in self.__dict__ @@ -2195,5 +2226,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) if fake_mode and fake_mode.shape_env: if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): - assert isinstance(out_proxy, Proxy) + assert isinstance(out_proxy, Proxy), out_proxy out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index add6cedb428a45..268f56e3214fb9 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -436,9 +436,7 @@ def guard_int(self, file, line): def guard_float(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred - r = self.shape_env.evaluate_expr( - self.expr, self.hint, fx_node=self.fx_node, expect_rational=False - ) + r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) try: return float(r) except Exception: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9c7f7ea80ce399..2243ab02f7f669 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -63,7 +63,7 @@ from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._sympy.functions import ( - FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt + Application, FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv, FloorToInt, CeilToInt ) from torch.utils._sympy.solve import try_solve from torch.utils._sympy.numbers import int_oo @@ -386,6 +386,40 @@ def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean: expr = sympy.logic.boolalg.to_cnf(expr) return _canonicalize_bool_expr_impl(expr) + +def _sympy_from_args( + cls: type, + args: List[sympy.Expr], + sort: bool = True, + is_commutative: Optional[bool] = None, +) -> sympy.Expr: + if not args: + return cls.identity + # These args are already in canonical form, so we avoid calling + # Add(*args) to avoid expensive Add.flatten operation + if sort: + if cls is sympy.Add: + sort_fn = sympy.core.add._addsort + elif cls is sympy.Mul: + sort_fn = sympy.core.mul._mulsort + else: + raise ValueError(f"Unknown cls: {cls}") + + # we don't support non commutative with sort + assert is_commutative is True + if args[0].is_Number: + rest = args[1:] + sort_fn(rest) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) + else: + args = args.copy() + sort_fn(args) + return cls._from_args(args, is_commutative=is_commutative) + else: + # if the args are already sorted, we create directly + return cls._from_args(args, is_commutative=is_commutative) + + def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: """ After canonicalization, we are guaranteed to have eliminated Ge/Gt relations @@ -404,7 +438,7 @@ def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: t = type(expr) def is_neg(t): - return t.is_negative or (isinstance(t, sympy.Mul) and t.args[0].is_negative) + return (t.is_Number and t.is_negative) or (isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative) lhs = 0 rhs = _reduce_to_lowest_terms(rhs) @@ -416,12 +450,16 @@ def is_neg(t): neg.append(-term) else: pos.append(term) - lhs = sympy.Add(*neg) - rhs = sympy.Add(*pos) + # these are already sorted + rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True) + # the terms were changed, so needs a sorting + lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) elif is_neg(rhs): # lhs == 0 lhs, rhs = -rhs, 0 - return t(lhs, rhs) + # We don't have to evaluate here because lhs, rhs came from a Boolean + # and it was already simplified + return t(lhs, rhs, evaluate=False) def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: @@ -432,20 +470,39 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: Useful when an expression is == or != to 0. """ def integer_coefficient(x): - if isinstance(x, sympy.Integer): + if x.is_Integer: return abs(int(x)) - elif isinstance(x, sympy.Mul): - return math.prod([abs(int(arg)) for arg in x.args if isinstance(arg, sympy.Integer)]) + elif x.is_Mul: + # If one of the args of a Mul is an Integer, it is the + # first arg. eg: args(2*x*3*y) == (6, x, y) + return abs(int(x.args[0])) if x.args[0].is_Integer else 1 else: return 1 - if isinstance(expr, sympy.Add): + def div_by_factor(x, factor): + if x.is_Integer: + return x / factor + elif x.is_Mul: + if x.args[0] != factor: + args = [x.args[0] / factor, *x.args[1:]] + else: + # Mul._from_args require a canonical list of args + # so we remove the first arg (x.args[0] / factor) if it was 1 + args = list(x.args[1:]) + return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative) + + if expr.is_Add: atoms = expr.args factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) - atoms = [x / factor for x in atoms] - return sympy.Add(*atoms) - else: - return expr / integer_coefficient(expr) + if factor == 1: + return expr + atoms = [div_by_factor(x, factor) for x in atoms] + return _sympy_from_args(sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative) + elif expr.is_Integer: + return sympy.One + elif expr.is_Mul: + return div_by_factor(expr, integer_coefficient(expr)) + return expr def is_concrete_bool(a: Union[bool, SymBool]) -> bool: @@ -597,8 +654,6 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None, """ if shape_env is None: return - if shape_env._ignore_fresh_unbacked_symbols_tls(): - return fs = shape_env.pending_fresh_unbacked_symbols pending = set(fs) if pending: @@ -1267,13 +1322,14 @@ def _is_supported_equivalence(expr): ) return isinstance(expr, sympy.Symbol) -def _has_unsupported_sympy_function(expr) -> bool: +def _has_uninterpretable_sympy_function(expr) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ return expr.has( torch.utils._sympy.functions.ToFloat, torch.utils._sympy.functions.TruncToInt, torch.utils._sympy.functions.CeilToInt, - # add more sympy functions that involve float<->int conversion here - # since our solver does not know what to do with them ) @dataclass(frozen=True) @@ -1394,13 +1450,64 @@ def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + +def _expandsums(args: List[sympy.Expr]) -> Tuple[sympy.Expr, bool]: + adds, other = [], [] + for arg in args: + if arg.is_Add: + adds.append(arg) + else: + other.append(arg) + + result = [sympy.Mul(*other)] + for add in adds: + result = [a * b for a, b in itertools.product(result, add.args)] + + result = sympy.Add(*result) + return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0) + + +def _fast_expand(expr: sympy.Expr) -> sympy.Expr: + # The expand algorithm in sympy is slow due to all the features is supports + # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is + # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement + # such features here to avoid expensive checks. We also make sure that we + # only re-create the objects if any of the args changed to avoid expensive + # checks when re-creating objects. + new_args = [_fast_expand(arg) for arg in expr.args] + if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + return _fast_expand(expr.func(*new_args)) + + if expr.is_Pow: + base, exp = expr.args + if exp.is_Integer and base.is_Add: + if exp > 1: + return sympy.expand_multinomial(expr, deep=False) + elif exp < 0: + return 1 / sympy.expand_multinomial(1 / expr, deep=False) + elif expr.is_Mul: + num, den = [], [] + for arg in expr.args: + if arg.is_Pow and arg.args[1] == -1: + den.append(1 / arg) + else: + num.append(arg) + + num, num_changed = _expandsums(num) + den, den_changed = _expandsums(den) + if num_changed or den_changed: + return num / den + + return expr + + @lru_cache(256) def safe_expand(r): if hasattr(r, 'expand'): try: - return sympy.expand(r) + return _fast_expand(r) except RecursionError: - log.warning("RecursionError in sympy.expand(%s)", r) + log.warning("RecursionError in _fast_expand(%s)", r) return r else: return r @@ -1675,6 +1782,15 @@ def __init__( # symbols that are marked dynamic self._marked_dynamic = marked_dynamic + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: Set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + def rewrite_with_congruences(self, s, expr): """ Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. @@ -1741,6 +1857,20 @@ def floor_div_handler(*args): expr = expr.replace(FloorDiv, floor_div_handler) return expr + def _enumerate_sympy_functions(self): + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions) + + def _has_unsupported_sympy_function(self, expr) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + def add(self, expr) -> bool: """Add an expression to the set of constraints. @@ -1757,7 +1887,7 @@ def add(self, expr) -> bool: # a fix for this issue, we delay raising such failures. See solve(). if orig_reduced == sympy.false: self._inconsistencies.append(f"{orig_expr} is inconsistent!") - if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr): + if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr): # we're not going to do anything useful with these, so drop them return False free_symbols = expr.free_symbols @@ -2689,6 +2819,7 @@ def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr): def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: """Used only when propagate_real_tensors; registers a value for an unbacked symbol, which can be used last resort to resolve hints.""" + log.info("set_unbacked_var_to_val %s = %s", k, v) self.unbacked_var_to_val[k] = sympy.sympify(v) # Unlike set_replacement, this records a shapeenv event @@ -2698,8 +2829,6 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): assert isinstance(new_s, sympy.Symbol), new_s assert free_unbacked_symbols(new_s), new_s assert free_unbacked_symbols(orig_s), orig_s - if self._ignore_fresh_unbacked_symbols_tls(): - return dest = self.replacements.get(orig_s) assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" self._set_replacement(orig_s, new_s, "rename_unbacked_to") @@ -4404,7 +4533,7 @@ def add_expr(expr): @_lru_cache def _maybe_evaluate_static( self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False, - expect_rational=True, size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, + size_oblivious: bool = False, axioms: Optional[Tuple[sympy.Expr]] = None, var_to_range: Optional[Tuple[Tuple[sympy.Symbol, ValueRanges]]] = None ) -> "Optional[sympy.Expr]": """ @@ -4442,7 +4571,7 @@ def _maybe_evaluate_static( subst = {} for e in axioms: if e.free_symbols.issubset(expr.free_symbols): - subst.update(dict(self.get_implications(e))) + subst.update(dict(self.get_implications(self.simplify(e)))) expr = expr.xreplace(subst) @@ -4549,8 +4678,18 @@ def _maybe_evaluate_static( def replace(self, expr: "sympy.Expr") -> "sympy.Expr": """Apply symbol replacements to any symbols in the given expression """ - replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} - return safe_expand(expr.xreplace(replacements)) + replacements = {} + for s in expr.free_symbols: + r = self._find(cast(sympy.Symbol, s)) + # Micro-optimization: only do replacements if r and s are different + # Otherwise, xreplace is not a no-op and will trigger expensive + # assumption queries if expr has a relational node. + if not r.is_Symbol or r != s: + replacements[s] = r + if replacements: + return safe_expand(expr.xreplace(replacements)) + else: + return expr @_lru_cache def _update_divisible(self): @@ -4584,8 +4723,9 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": if self.replace(Mod(base, divisor)) in self.divisible and \ base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: div_replacements[atom] = divisor1 - expr = expr.xreplace(div_replacements) - expr = safe_expand(expr) + if div_replacements: + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) if expr.has(FloorDiv): div_replacements = {} pows = expr.atoms(sympy.Pow) @@ -4594,13 +4734,14 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: div_replacements[fd] = CleanDiv(base, divisor) - new_expr = expr.xreplace(div_replacements) - new_expr = safe_expand(new_expr) - new_pows = new_expr.atoms(sympy.Pow) - new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) - # divisions simplified away - if new_pows.issubset(pows) and new_rationals.issubset(rationals): - expr = new_expr + if div_replacements: + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference(new_expr.atoms(sympy.Integer)) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr return expr @lru_cache(256) @@ -5097,18 +5238,18 @@ def _log_guard(self, prefix: str, g, forcing_spec: bool): @lru_cache(256) @record_shapeenv_event(save_tracked_fakes=True) def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, - expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False): + size_oblivious: bool = False, *, forcing_spec: bool = False): try: - return self._evaluate_expr(orig_expr, hint, fx_node, expect_rational, size_oblivious, forcing_spec=forcing_spec) + return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec) except Exception: self.log.warning( - "failed during evaluate_expr(%s, hint=%s, expect_rational=%s, size_oblivious=%s, forcing_spec=%s", - orig_expr, hint, expect_rational, size_oblivious, forcing_spec + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, hint, size_oblivious, forcing_spec ) raise def _evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None, - expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False): + size_oblivious: bool = False, *, forcing_spec: bool = False): """ Given an expression, evaluates it, adding guards if necessary """ @@ -5176,7 +5317,6 @@ def compute_concrete_val(): expr = orig_expr static_expr = self._maybe_evaluate_static(expr, - expect_rational=expect_rational, size_oblivious=size_oblivious) if static_expr is not None: self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr) @@ -5196,7 +5336,6 @@ def compute_concrete_val(): if not size_oblivious: size_oblivious_result = self._maybe_evaluate_static( expr, - expect_rational=expect_rational, size_oblivious=True ) @@ -5570,8 +5709,17 @@ def _suggest_fixes_for_data_dependent_error_non_strict(e): # map symbol names reachable via frame locals to their source-level names src_map = defaultdict(list) for var, val in frame.f_locals.items(): + try: + tree_leaves_with_path = pytree.tree_leaves_with_path(val) + except ValueError: + log.warning( + "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}", + type(val), + var, + ) + continue # figure out how to access any symbol inside `val` through `var` - for path, leaf in pytree.tree_leaves_with_path(val): + for path, leaf in tree_leaves_with_path: name = var + pytree.keystr(path) if isinstance(leaf, torch.SymInt): src_map[str(leaf.node.expr)].append(name) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 62daf414ec4f4a..b0df9f02fcb8e6 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -20,6 +20,7 @@ import math import warnings import inspect +import functools __all__ = ["PythonCode", "CodeGen", "Graph"] @@ -37,6 +38,8 @@ tuple: Tuple, } +_legal_ops = dict.fromkeys(['call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output']) + # Signature for functions thattransforms the body (`list[str]`) of the # generated code @@ -84,14 +87,11 @@ def _snake_case(s: str) -> str: ``mod.pascalCase``-> ``mod.pascal_case`` ``mod.ALL_CAPS`` -> ``mod.all_caps`` """ - chars = [] - prev_lower = False - for c in s: - if prev_lower and c.isupper(): - chars.append('_') - chars.append(c.lower()) - prev_lower = c.islower() - return ''.join(chars) + return _snake_case_sub(s).lower() + + +# Replace occurrences where a lowercase letter is followed by an uppercase letter +_snake_case_sub = functools.partial(re.compile(r'(?<=[a-z])([A-Z])').sub, r'_\1') def _is_from_torch(obj: Any) -> bool: @@ -816,10 +816,10 @@ def remove(self, node: Node) -> None: def find_nodes(self, *, op: str, target: Optional['Target'] = None): if op == "call_function": assert target is not None - return dict(self.table[(op, target)]).keys() + return [*self.table[(op, target)].keys()] if target is None: - return dict(self.table[(op, None)]).keys() + return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] @@ -1010,7 +1010,7 @@ def create_node(self, op: str, target: 'Target', The newly-created and inserted node. """ - assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') + assert op in _legal_ops args = () if args is None else args kwargs = {} if kwargs is None else kwargs assert isinstance(args, tuple), "args must be a tuple" @@ -1551,6 +1551,8 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: # Check targets are legit if self.owning_module: + num_warnings = 0 + MAX_WARNINGS = 5 for node in self.nodes: if node.op == 'call_function': if not callable(node.target): @@ -1577,11 +1579,21 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: and not isinstance(new_m_itr, torch.nn.Module) and not isinstance(new_m_itr, torch.nn.Parameter) and atom not in m_itr._buffers): - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') + if num_warnings < MAX_WARNINGS: + # Don't emit this warning too frequently, + # for very large graphs this can become very expensive + # from a performance perspective. + warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' + 'not reference an nn.Module, nn.Parameter, or buffer, which is ' + 'what \'get_attr\' Nodes typically target') + num_warnings += 1 else: m_itr = new_m_itr + if num_warnings > MAX_WARNINGS: + warnings.warn( + f'Additional {num_warnings - MAX_WARNINGS} warnings ' + 'suppressed about get_attr references' + ) @compatibility(is_backward_compatible=True) def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): diff --git a/torch/fx/node.py b/torch/fx/node.py index 218fcebe332107..8c3461cbe23c7c 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -33,6 +33,8 @@ BaseArgumentTypes ]] +_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) + _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { torch._C._set_grad_enabled, torch.amp._enter_autocast, @@ -164,6 +166,18 @@ class Node(_NodeBase): - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. """ + _args: Tuple['Argument', ...] + _kwargs: Dict[str, 'Argument'] + graph: 'Graph' + name: str + op: str + target: 'Target' + _input_nodes: Dict['Node', None] + users: Dict['Node', None] + type: Optional[Any] + _sort_key: Any + _repr_fn: Optional[Callable[['Node'], str]] + meta: Dict[str, Any] @compatibility(is_backward_compatible=True) def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', @@ -195,11 +209,7 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', annotation of values in the generated code or for other types of analyses. """ - super().__init__() - self.graph = graph - self.name = name # unique name of value being created - assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root'] - self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + assert op in _legal_ops if op == 'call_function': if not callable(target): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' @@ -208,21 +218,31 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', if not isinstance(target, str): raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' 'but a str is expected') - self.target = target # for method/module/function, the name of the method/module/function/attr + super().__init__() + + # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects + assign = object.__setattr__ + + assign(self, "graph", graph) + assign(self, "name", name) # unique name of value being created + assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + + assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. # The public API for this is `all_input_nodes`, this private attribute # should not be accessed directly. - self._input_nodes : Dict[Node, None] = {} - self.__update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore[arg-type] + assign(self, "_input_nodes", {}) + self.__update_args_kwargs(args, kwargs) # All of the nodes that use the value produced by this Node # Note one user may correspond to several uses, e.g. the node fo ``x + x`` # would appear once here, but represents two uses. # # Is a dict to act as an "ordered set". Keys are significant, value dont-care - self.users : Dict[Node, None] = {} + assign(self, "users", {}) + # Type expression representing the output value of this node. # This should contain the same class of Type objects that would appear # as type annotations for function inputs/outputs. @@ -233,15 +253,15 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. - self.type : Optional[Any] = return_type - self._sort_key: Any = () + assign(self, "type", return_type) + assign(self, "_sort_key", ()) # If set, use this fn to print this node - self._repr_fn : Optional[Callable[[Node], str]] = None + assign(self, "_repr_fn", None) # Dictionary to store metadata passes need to do their # transformations. This metadata is preserved across node copies - self.meta : Dict[str, Any] = {} + assign(self, "meta", {}) def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() @@ -364,7 +384,7 @@ def args(self, a : Tuple[Argument, ...]) -> None: """ # DO NOT CALL `__update_args_kwargs` directly. The correct way to # set `args` is via direct assignment, i.e. `node.args = new_args` - self.__update_args_kwargs(map_arg(a, lambda x: x), self._kwargs) # type: ignore[arg-type] + self.__update_args_kwargs(a, self._kwargs) @property def kwargs(self) -> Dict[str, Argument]: @@ -387,7 +407,7 @@ def kwargs(self, k : Dict[str, Argument]) -> None: """ # DO NOT CALL `__update_args_kwargs` directly. The correct way to # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` - self.__update_args_kwargs(self._args, map_arg(k, lambda x: x)) # type: ignore[arg-type] + self.__update_args_kwargs(self._args, k) @property def all_input_nodes(self) -> List['Node']: @@ -453,9 +473,7 @@ def update_kwarg(self, key : str, arg : Argument) -> None: key (str): The key in ``self.kwargs`` of the element to update arg (Argument): The new argument value to write into ``kwargs`` """ - kwargs = dict(self.kwargs) - kwargs[key] = arg - self.kwargs = kwargs + self.kwargs = {**self.kwargs, key: arg} @property def stack_trace(self) -> Optional[str]: @@ -479,18 +497,23 @@ def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : D """ This API is internal. Do *not* call it directly. """ - self._args = new_args - self._kwargs = new_kwargs + def update_users_and_input_nodes(n: Any) -> Any: + if isinstance(n, Node): + self._input_nodes.setdefault(n) + n.users.setdefault(self) + return n + # Clear prior users and input_nodes for old_use in self._input_nodes.keys(): old_use.users.pop(self) + object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__ - self._input_nodes = {} - map_arg(self._args, self._input_nodes.setdefault) - map_arg(self._kwargs, self._input_nodes.setdefault) - - for new_use in self._input_nodes.keys(): - new_use.users.setdefault(self) + # We do three things in a single pass of the args + # - Normalize list->immutable_list, dict->immutable_dict, etc + # - Populate self._input_nodes + # - Populate arg.users[self] for each arg + object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) + object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) def __repr__(self) -> str: if self._repr_fn: @@ -765,13 +788,16 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """ if isinstance(a, tuple): - t = tuple(map_aggregate(elem, fn) for elem in a) + t = tuple([map_aggregate(elem, fn) for elem in a]) # Support NamedTuple (if it has `_fields`) by repacking into original type. return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] elif isinstance(a, list): - return immutable_list(map_aggregate(elem, fn) for elem in a) + return immutable_list([map_aggregate(elem, fn) for elem in a]) elif isinstance(a, dict): - return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) + rv = immutable_dict() + for k, v in a.items(): + dict.__setitem__(rv, k, map_aggregate(v, fn)) + return rv elif isinstance(a, slice): return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) else: diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 7307fb8c835703..01803600d021ab 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import functools import logging import operator import sys @@ -92,8 +93,9 @@ def insert_deferred_runtime_asserts( # Import sympy locally import sympy + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook from torch.fx.experimental.symbolic_shapes import ( - _has_unsupported_sympy_function, + _has_uninterpretable_sympy_function, CallMethodKey, cast_symbool_to_symint_guardless, ConvertIntKey, @@ -136,7 +138,7 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: (val := _get_sym_val(node)) is not None and not isinstance(val, sympy.Number) # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported - and not _has_unsupported_sympy_function(val) + and not _has_uninterpretable_sympy_function(val) and any( isinstance(arg, fx.Node) and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) @@ -145,6 +147,36 @@ def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: ) ) + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[Dict[str, Any]] = None, + ) -> None: + fake_args = [ + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ] + try: + node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] + except NotImplementedError: + # This can happen when attempting to reify a symbol with an unsupported call_function node, + # e.g. with NestedTensors + sym_size.int via match_symbol(). + # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. + pass + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + # Track asserts/checks we've added added_asserts: Set[sympy.Expr] = set() constrained_unbacked_symbols: Set[sympy.Symbol] = set() @@ -195,7 +227,7 @@ def add_runtime_asserts(ras): and _is_bound_expr_for_symbol(ra.expr) ) # don't try to reify sympy functions we can't turn into FX nodes - or _has_unsupported_sympy_function(ra.expr) + or _has_uninterpretable_sympy_function(ra.expr) ): continue @@ -211,16 +243,17 @@ def add_runtime_asserts(ras): else: # Convert the sympy expression into a sequence of FX # nodes - res = _sympy_interp(expr_to_proxy, ra.expr).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - # TODO: use ra.msg here, but it's pretty - # useless right now - ( - res, - f"Runtime assertion failed for expression {ra.expr} on node '{res}'", - ), - ) + with _set_node_metadata_hook(gm, _node_metadata_hook): + res = _sympy_interp(expr_to_proxy, ra.expr).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Runtime assertion failed for expression {ra.expr} on node '{res}'", + ), + ) added_asserts.add(ra.expr) nodes = list(graph.nodes) @@ -247,7 +280,8 @@ def match_symbol(symint, cb): and isinstance(s := symint.node.expr, sympy.Symbol) and s not in expr_to_proxy ): - expr_to_proxy[s] = fx.Proxy(cb()) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb()) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) match_symbol(example_value, lambda: node) @@ -331,7 +365,15 @@ def match_symbol(symint, cb): if _is_intermediate_tensor_sym_call( node ): # reify from input shapes - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) @@ -436,7 +478,8 @@ def go(node, keypath): raise AssertionError(f"unrecognized keypath {keypath}") if s not in expr_to_proxy: - expr_to_proxy[s] = fx.Proxy(go(node, keypath)) + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(go(node, keypath)) log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) for i0 in defs: @@ -519,26 +562,34 @@ def convert(s): # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts # raises AOTAutograd errors on cast_symbool_to_symint_guardless - if (min_val := convert(vr.lower)) is not None: - ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - ge, - f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", - ), - ) - added_asserts.add(i0 >= min_val) - if (max_val := convert(vr.upper)) is not None: - le = _sympy_interp(expr_to_proxy, i0 <= max_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - le, - f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", - ), - ) - added_asserts.add(i0 <= max_val) + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) constrained_unbacked_symbols.add(i0) add_runtime_asserts(ras) diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 324e8a67801564..11a9cfa34898ab 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -118,7 +118,7 @@ def fuse_as_graphmodule(gm: GraphModule, for node in nodes: assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" + assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}" # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 8482dca74b1002..78b063ff8a7aaa 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -61,8 +61,8 @@ def target_graph(x, weight): relu *= 2 return relu - pattern_gm = capture_pre_autograd_graph(pattern, example_inputs) - target_gm = capture_pre_autograd_graph(target_graph, example_inputs) + pattern_gm = export_for_training(pattern, example_inputs).module() + target_gm = export_for_training(target_graph, example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) matches = matcher.match(target_gm) for match in matches: diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 2b86a1c609f918..86927595eac91e 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -97,6 +97,7 @@ def __exit__(self, *args): "original_aten", "recompute", "ac_graph_id", + "has_backward_hook", "from_node", "quantization_tag", # TODO deprecated "_numeric_debug_handle", # TODO deprecated @@ -152,6 +153,7 @@ def create_node(self, kind : str, target : Target, self.scope.module_path, self.scope.module_type, ) + # Optionally set stack trace on the created Node for debugging purposes if fx_traceback.has_preserved_node_meta(): current_meta: Dict[str, Any] = fx_traceback.get_current_meta() @@ -260,29 +262,33 @@ def create_arg(self, a: Any) -> Argument: Can be override to support more trace-specific types. """ - if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): + if isinstance(a, Proxy): + return a.node # most common arg type goes first + elif hasattr(a, '__fx_create_arg__'): return a.__fx_create_arg__(self) # aggregates - elif isinstance(a, tuple) and hasattr(a, '_fields'): - # NamedTuple constructors don't seem to like getting a generator - # expression as an argument to their constructor, so build this - # intermediate tuple and unpack it into the NamedTuple constructor - args = tuple(self.create_arg(elem) for elem in a) - return type(a)(*args) # type: ignore[arg-type] - elif isinstance(a, (tuple, list)): - return type(a)(self.create_arg(elem) for elem in a) + elif isinstance(a, tuple): + if hasattr(a, '_fields'): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = [self.create_arg(elem) for elem in a] + return type(a)(*args) # type: ignore[arg-type] + return type(a)([self.create_arg(elem) for elem in a]) + elif isinstance(a, list): + return [self.create_arg(elem) for elem in a] elif isinstance(a, dict): + def no_node(arg): + if isinstance(arg, Node): + raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}") + r = {} for k, v in a.items(): # Check for invalid dict keys. We do not want a Proxy to appear # anywhere within the key. Since keys can be collection types, # we iterate through the key with map_aggregate k = self.create_arg(k) - - def no_node(arg): - if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - f"Node. Got key: {k}") map_aggregate(k, no_node) r[k] = self.create_arg(v) @@ -296,16 +302,13 @@ def no_node(arg): elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return a - if isinstance(a, Proxy): - # base case: we unwrap the Proxy object - return a.node - - if is_dataclass(a): + elif is_dataclass(a): kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} return self.create_node("call_function", a.__class__, (), kwargs) elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: return a + raise NotImplementedError(f"argument of type: {type(a)}") @compatibility(is_backward_compatible=True) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 7f2cb743d2cdd9..c0d88821d7fafc 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -207,9 +207,11 @@ def forward(self, x, w1, w2): def replace_pattern_with_filters( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], - replacement: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, ) -> List[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -219,17 +221,22 @@ def replace_pattern_with_filters( (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating whether the match satisfies the condition. See matcher_utils.py for definition of InternalMatch. + ``replacement_callback``: A function that takes in a match and returns a + Graph to be used as the replacement. This allows you to construct a + replacement graph based on the match. """ - return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals) + return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback) def _replace_pattern( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], - replacement: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, ) -> List[ReplacedPatterns]: from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch @@ -247,13 +254,6 @@ def _replace_pattern( else: pattern_graph = symbolic_trace(pattern).graph - if isinstance(replacement, GraphModule): - replacement_graph = replacement.graph - elif isinstance(replacement, Graph): - replacement_graph = replacement - else: - replacement_graph = symbolic_trace(replacement).graph - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, remove_overlapping_matches=True, ignore_literals=ignore_literals) _matches: List[InternalMatch] = matcher.match(original_graph) @@ -265,13 +265,27 @@ def _replace_pattern( for match_filter in match_filters) ] - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + if isinstance(replacement, GraphModule): + common_replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + common_replacement_graph = replacement + elif callable(replacement): + common_replacement_graph = symbolic_trace(replacement).graph + else: + assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback" + common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change match_changed_node: Dict[Node, Node] = {} match_and_replacements = [] - for match in _matches: + for i, match in enumerate(_matches): + if replacement_callback is not None: + replacement_graph = replacement_callback(match, original_graph, pattern_graph) + else: + assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback" + replacement_graph = common_replacement_graph + replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] # Build connecting between replacement graph's input and original graph input producer node diff --git a/torch/hub.py b/torch/hub.py index 096e463fcba10f..c037c6e9dc139a 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from typing_extensions import deprecated from urllib.error import HTTPError, URLError from urllib.parse import urlparse # noqa: F401 @@ -389,7 +389,7 @@ def _load_entry_from_hubconf(m, model): return func -def get_dir(): +def get_dir() -> str: r""" Get the Torch Hub cache directory used for storing downloaded models & weights. @@ -408,7 +408,7 @@ def get_dir(): return os.path.join(_get_torch_home(), "hub") -def set_dir(d): +def set_dir(d: Union[str, os.PathLike]) -> None: r""" Optionally set the Torch Hub directory used to save downloaded models & weights. diff --git a/torch/library.py b/torch/library.py index 3fc44f191ba383..4ac89c5259b071 100644 --- a/torch/library.py +++ b/torch/library.py @@ -1228,17 +1228,35 @@ def opcheck( ``opcheck`` tests these metadata and properties. Concretely, we test the following: - - test_schema: if the operator's schema is correct. - - test_autograd_registration: if autograd was registered correctly. + + - test_schema: If the schema matches the implementation of + the operator. For example: if the schema specifies a Tensor is mutated, + then we check the implementation mutates the Tensor. If the schema + specifies that we return a new Tensor, then we check that the + implementation returns a new Tensor (instead of an existing one or + a view of an existing one). + - test_autograd_registration: If the operator supports training + (autograd): we check that its autograd formula is registered via + torch.library.register_autograd or a manual registration to one + or more DispatchKey::Autograd keys. Any other DispatchKey-based + registrations may lead to undefined behavior. - test_faketensor: If the operator has a FakeTensor kernel - (and if it is correct). The FakeTensor kernel is necessary ( - but not sufficient) for the operator to work with PyTorch compilation - APIs (torch.compile/export/FX). + (and if it is correct). The FakeTensor kernel is necessary ( + but not sufficient) for the operator to work with PyTorch compilation + APIs (torch.compile/export/FX). We check that a FakeTensor kernel + (also sometimes known as a meta kernel) was registered for the + operator and that it is correct. This test takes the result of + running the operator on real tensors and the result of running + the operator on FakeTensors and checks that they have the same + Tensor metadata (sizes/strides/dtype/device/etc). - test_aot_dispatch_dynamic: If the operator has correct behavior - with PyTorch compilation APIs (torch.compile/export/FX). - This checks that the outputs (and gradients, if applicable) are the - same under eager-mode PyTorch and torch.compile. - This test is a superset of ``test_faketensor``. + with PyTorch compilation APIs (torch.compile/export/FX). + This checks that the outputs (and gradients, if applicable) are the + same under eager-mode PyTorch and torch.compile. + This test is a superset of ``test_faketensor`` and is an e2e test; + other things it tests are that the operator supports + functionalization and that the backward pass (if it exists) also + supports FakeTensor and functionalization. For best results, please call ``opcheck`` multiple times with a representative set of inputs. If your operator supports diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 3377d30c665477..cef76fec1107d5 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1,5 +1,3 @@ -import sys - import torch from torch._C import _add_docstr, _linalg # type: ignore[attr-defined] @@ -2234,7 +2232,7 @@ equations with a unique solution. Args: - A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) + A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= False`) where `*` is zero or more batch dimensions. B (Tensor): right-hand side tensor of shape `(*, n, k)`. diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 72f8249785d288..4962d0430992e6 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1,7 +1,7 @@ -# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import warnings -from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec import torch from torch import sym_float, Tensor @@ -23,13 +23,16 @@ __all__: List[str] = [] +_T = TypeVar("_T") +_P = ParamSpec("_P") + # All masked reduction/normalization operations have the same # signatures. Here we introduce docstring templates that are applied # to docstrings of reduction/normalization functions via # _apply_docstring_templates decorator. -def _apply_docstring_templates(func): +def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]: """Decorator that applies docstring templates to function docstring and returns the function instance. """ @@ -418,9 +421,16 @@ def _reduction_identity(op_name: str, input: Tensor, *args): return torch.tensor(0, dtype=dtype, device=device) elif op_name in {"prod", "cumprod"}: return torch.tensor(1, dtype=dtype, device=device) - elif op_name in {"amax", "argmax", "logsumexp"}: + elif op_name in {"amax", "argmax", "logaddexp"}: + if torch.is_floating_point(input): + return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_signed(input) or dtype == torch.uint8: + return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) + elif op_name in {"logsumexp"}: if torch.is_floating_point(input): return torch.tensor(-torch.inf, dtype=dtype, device=device) + elif torch.is_complex(input): + return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device) elif torch.is_signed(input) or dtype == torch.uint8: return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device) elif op_name in {"amin", "argmin"}: @@ -1523,8 +1533,8 @@ def logaddexp( if dtype is None: dtype = input.dtype if input.layout == torch.strided and other.layout == torch.strided: - mask_input = _combine_input_and_mask(logsumexp, input, input_mask) - mask_other = _combine_input_and_mask(logsumexp, other, other_mask) + mask_input = _combine_input_and_mask(logaddexp, input, input_mask) + mask_other = _combine_input_and_mask(logaddexp, other, other_mask) return torch.logaddexp(mask_input, mask_other).to(dtype=dtype) else: raise ValueError( diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 69850df81c65a1..366cf45eb2d50f 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -258,7 +258,7 @@ def _set_data_mask(self, data, mask): self._masked_mask = mask self._validate_members() - def __repr__(self): + def __repr__(self): # type: ignore[override] formatter = "{0:8.4f}" if self.dim() == 0: scalar_data = self.get_data().item() @@ -350,7 +350,7 @@ def get_mask(self): def is_sparse_coo(self): return self.layout == torch.sparse_coo - def is_sparse_csr(self): + def is_sparse_csr(self): # type: ignore[override] return self.layout == torch.sparse_csr # Update later to support more sparse layouts diff --git a/torch/masked/maskedtensor/passthrough.py b/torch/masked/maskedtensor/passthrough.py index 4a2e79456c86db..ba13f50c1fee9c 100644 --- a/torch/masked/maskedtensor/passthrough.py +++ b/torch/masked/maskedtensor/passthrough.py @@ -30,6 +30,11 @@ torch.ops.aten._reshape_alias, torch.ops.aten.cat, torch.ops.aten.unsqueeze, + torch.ops.aten.unfold, + torch.ops.aten.unfold_backward, + torch.ops.aten.im2col, + torch.ops.aten.col2im, + torch.ops.aten.stack, ] diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index eaa107e973f6d9..af711d49260600 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -166,6 +166,17 @@ def memory_stats(device: Optional[_device_t] = None) -> Dict[str, Any]: return torch._C._mtia_memoryStats(_get_device_index(device, optional=True)) +def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: + r"""Return capability of a given device as a tuple of (major version, minor version). + + Args: + device (torch.device or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). + """ + return torch._C._mtia_getDeviceCapability(_get_device_index(device, optional=True)) + + def set_stream(stream: Stream): r"""Set the current stream.This is a wrapper API to set the stream. Usage of this function is discouraged in favor of the ``stream`` @@ -323,6 +334,7 @@ def set_rng_state( "current_stream", "default_stream", "memory_stats", + "get_device_capability", "set_device", "set_stream", "stream", diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index fa0818571a93c0..4e4539396f8323 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -2,7 +2,7 @@ import multiprocessing import os import threading -from multiprocessing.reduction import ForkingPickler +from multiprocessing import reduction from multiprocessing.util import register_after_fork from typing import Union @@ -74,7 +74,7 @@ def __init__(self) -> None: def _after_fork(self): self.lock = threading.Lock() - def get(self, key): + def get(self, key): # type: ignore[override] with self.lock: return dict.get(self, key) @@ -626,22 +626,22 @@ def reduce_storage(storage): def init_reductions(): - ForkingPickler.register(torch.cuda.Event, reduce_event) + reduction.register(torch.cuda.Event, reduce_event) for t in torch._storage_classes: if t.__name__ == "UntypedStorage": - ForkingPickler.register(t, reduce_storage) + reduction.register(t, reduce_storage) else: - ForkingPickler.register(t, reduce_typed_storage_child) + reduction.register(t, reduce_typed_storage_child) - ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage) + reduction.register(torch.storage.TypedStorage, reduce_typed_storage) for t in torch._tensor_classes: - ForkingPickler.register(t, reduce_tensor) + reduction.register(t, reduce_tensor) # TODO: Maybe this should be in tensor_classes? :) - ForkingPickler.register(torch.Tensor, reduce_tensor) + reduction.register(torch.Tensor, reduce_tensor) from torch.nn.parameter import Parameter - ForkingPickler.register(Parameter, reduce_tensor) + reduction.register(Parameter, reduce_tensor) diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 756e9aa670c9ca..74bdde0fd97b20 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -233,6 +233,7 @@ def start_processes( start_method == "forkserver" and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1" ): + log.info("Starting processes in parallel.") start_parallel = True else: # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 6a0425a13d43af..d39eb12d919c87 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -209,7 +209,7 @@ def _max_seqlen(self): def _min_seqlen(self): return self._get_min_seqlen() - def __repr__(self): + def __repr__(self): # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" @@ -562,3 +562,28 @@ def nested_view_from_values_offsets_lengths( min_seqlen_tensor, max_seqlen_tensor, ) # type: ignore[return-value] + + +def nested_from_padded( + padded, offsets, ragged_idx=1, min_seqlen=None, max_seqlen=None, sum_S=None +): + if ragged_idx != 1: + raise RuntimeError("nested_from_padded(): only ragged_idx=1 supported for now") + + min_seqlen_tensor = None + if min_seqlen is not None: + min_seqlen_tensor = _store_val_in_tensor(min_seqlen) + + max_seqlen_tensor = None + if max_seqlen is not None: + max_seqlen_tensor = _store_val_in_tensor(max_seqlen) + + return torch._nested_from_padded_tensor( + padded, + offsets, + _nt_view_dummy(), + ragged_idx, + min_seqlen_tensor, + max_seqlen_tensor, + sum_S, + ) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 0d203270b820b2..14f227e0886230 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -346,6 +346,29 @@ def _flatten_sig(input, start_dim=0, end_dim=-1): return inp.reshape(*new_shape) + # Handle nested-specific input validation for CompositeImplicit rms_norm + if func.__name__ == "rms_norm": + + def _rms_norm_sig(input, normalized_shape, weight=None, eps=None): + pass + + _, new_kwargs = normalize_function( # type: ignore[misc] + _rms_norm_sig, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + normalized_shape = new_kwargs.pop("normalized_shape") + + # can't normalize over the ragged dim (yet) + max_normalizable = inp.dim() - inp._ragged_idx - 1 + if len(normalized_shape) > max_normalizable: + raise ValueError( + "rms_norm(): Normalization over the ragged dim not supported for nested tensors" + ) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + raise NotImplementedError(func) @@ -395,7 +418,7 @@ def prim_layout_default(func, *args, **kwargs): def tensor_attr_unsupported_getter(func, *args, **kwargs): if func == torch.ops.aten.size.default: raise RuntimeError( - "NestedTensors does not support directly calling torch.ops.aten.size " + "NestedTensor does not support directly calling torch.ops.aten.size; " "please use `nested_tensor.size()` instead." ) @@ -1017,6 +1040,17 @@ def is_same_size_default(func, *args, **kwargs): return args[0]._size == args[1]._size +@register_jagged_func(torch.ops.aten.sum.default, "self: jt_all, dtype: any?") +def sum_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + return func(inp._values, **new_kwargs) + + @register_jagged_func( torch.ops.aten.sum.dim_IntList, "self: jt_all, dim: any?, keepdim: any?, dtype: any?", @@ -1097,8 +1131,7 @@ def sum_dim_IntList(func, *args, **kwargs): ) # need to read offsets --> pad jagged dimension and apply sum if new_kwargs["keepdim"]: - # TODO: Fix this; it's a bug. should be unsqueezing on ragged_idx - out = out.unsqueeze(0) + out = out.unsqueeze(inp._ragged_idx) return out else: # raggedness preserved --> return nested tensor if ( @@ -1159,6 +1192,44 @@ def transpose_int(func, *args, **kwargs): return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp)) +@register_jagged_func(torch.ops.aten.permute.default, "self: jt_all, dims: any") +def permute_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + dims = new_kwargs.pop("dims") + inp_kwargs = extract_kwargs(inp) + inp_dim = len(inp._size) + + # The first two checks are the same as the checks in the normal permute implementation + if inp_dim != len(dims): + raise ValueError( + f"permute(): number of dimensions in the tensor input ({inp_dim}) " + + f"does not match the length of the desired ordering of dimensions ({len(dims)}).", + ) + + from torch._prims_common import canonicalize_dims + + canonicalized_dims = canonicalize_dims(inp_dim, dims) + + if len(canonicalized_dims) != len(set(canonicalized_dims)): + raise ValueError("permute(): duplicate dims are not allowed.") + + if inp._lengths is not None: + raise ValueError( + "permute(): not supported on jagged layout nested tensor with holes" + ) + if canonicalized_dims[0] != 0: + raise ValueError( + "Permute is not supported on the batch dimension for jagged NT" + ) + inp_kwargs["_ragged_idx"] = canonicalized_dims.index(inp._ragged_idx) + inner_dims = [_outer_to_inner_dim(inp_dim, dim) for dim in canonicalized_dims[1:]] + new_kwargs["dims"] = inner_dims + return NestedTensor(func(inp._values, **new_kwargs), **inp_kwargs) + + @register_jagged_func( [torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default], "self: jt_all, size: any", @@ -1386,13 +1457,16 @@ def mean_dim(func, *args, **kwargs): func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True ) - if len(new_kwargs["dim"]) > 1: + inp = new_kwargs.pop("input") + + if len(new_kwargs["dim"]) > 1 and ( + inp._ragged_idx in new_kwargs["dim"] or 0 in new_kwargs["dim"] + ): raise RuntimeError( - "mean(): not supported across multiple dimensions for NestedTensor" + "mean(): not supported across multiple dimensions for NestedTensor " + "when either the batch dim or ragged dim is included" ) - inp = new_kwargs.pop("input") - ( new_kwargs["dim"], reduce_on_batch, @@ -1425,8 +1499,10 @@ def mean_dim(func, *args, **kwargs): # for every non-batch dimension, # unsqueeze lengths into the same shape as the PyTorch sum, # as the extra dimensions must all be divided by the same length + # Note: keepdim=True is on at this point so lengths has to be unsqueezed for + # that 1-size dim as well. lengths = inp._offsets.diff() - for _ in range(inp.dim() - 2): + for _ in range(inp.dim() - 1): lengths = lengths.unsqueeze(-1) return torch_sum / lengths.broadcast_to(torch_sum.shape) @@ -1515,6 +1591,98 @@ def all_default(func, *args, **kwargs): return func(inp._values) +@register_jagged_func( + torch.ops.aten.to_padded_tensor.default, "self: jt, padding: any, output_size: any?" +) +def to_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + + # TODO: Handle the rest of output_size + output_size = new_kwargs["output_size"] + if output_size is not None: + max_seq_len = output_size[inp._ragged_idx] + else: + max_seq_len = inp._max_seqlen + + # only 2D values is supported by the underlying FBGEMM kernel so do shape + # gymnastics if needed + values = inp.values() + values_shape = values.shape + if values.dim() > 2: + values = values.flatten(start_dim=1) + elif values.dim() == 1: + values = values.unsqueeze(-1) + + padded_out = torch.ops.aten._jagged_to_padded_dense_forward( + values, + [inp._offsets], + [max_seq_len], + new_kwargs["padding"], + ) + + # shape gymnastics part 2 + if len(values_shape) > 2: + padded_out = padded_out.unflatten(-1, values_shape[1:]) + elif len(values_shape) == 1: + padded_out = padded_out.squeeze(-1) + + return padded_out + + +@register_jagged_func( + torch.ops.aten._nested_from_padded_tensor.default, + "padded: t, offsets: t, dummy: jt, ragged_idx: any?, min_seqlen: any?, max_seqlen: any?, sum_S: any?", +) +def _nested_from_padded_tensor_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + if new_kwargs["ragged_idx"] != 1: + raise RuntimeError( + "_nested_from_padded_tensor(): only ragged_idx=1 supported for jagged layout" + ) + + padded, offsets = new_kwargs["padded"], new_kwargs["offsets"] + + # non-3D padded is not supported by the underlying FBGEMM kernel so do shape gymnastics + padded_shape = padded.shape + if padded.dim() > 3: + padded = padded.flatten(start_dim=2) + elif padded.dim() < 3: + padded = padded.unsqueeze(-1) + + values = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], new_kwargs["sum_S"] + ) + + # shape gymnastics part 2 + if len(padded_shape) > 3: + values = values.unflatten(-1, padded_shape[2:]) + elif len(padded_shape) < 3: + values = values.squeeze(-1) + + ragged_idx = new_kwargs["ragged_idx"] + min_seqlen = new_kwargs["min_seqlen"] + max_seqlen = new_kwargs["max_seqlen"] + metadata_cache = {} + if min_seqlen is not None: + metadata_cache["min_seqlen"] = min_seqlen + if max_seqlen is not None: + metadata_cache["max_seqlen"] = max_seqlen + + return NestedTensor( + values, + offsets, + _ragged_idx=ragged_idx, + _metadata_cache=metadata_cache, + ) + + @register_jagged_func( torch.ops.aten._nested_view_from_jagged.default, "values: t, offsets: t, dummy: jt_all, lengths: t?, ragged_idx: any?, min_seqlen: t?, max_seqlen: t?", diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 73d078d04c8a17..578904af946971 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -615,6 +615,20 @@ def _post_process_flash_output(out: torch.Tensor, og_size): return out +def _is_computing_meta_flops(x): + # Note: there's a use case of using meta tensors & the dispatch-based flop counter. + # We can use this function to check for this scenario in order to handle it specially. + if not torch.jit.is_scripting() and x.device.type == "meta": + torch_dispatch_mode_stack = ( + torch.utils._python_dispatch._get_current_dispatch_mode_stack() + ) + return any( + type(x) == torch.utils.flop_counter.FlopCounterMode + for x in torch_dispatch_mode_stack + ) + return False + + def _autocast( query: torch.Tensor, key: torch.Tensor, @@ -642,7 +656,8 @@ def _autocast( actual dtype conversions. """ device_type = query.device.type - if not torch.is_autocast_enabled(device_type): + # meta device is not supported by autocast, so break early for it + if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type): return query, key, value, attn_mask def cvt(x): @@ -703,6 +718,13 @@ def jagged_scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, is_causal, enable_gqa ) + if _is_computing_meta_flops(query): + # Backend choice will probably not be correct if we have a meta device, + # because backend choice is device-aware. In this case, we mostly just + # want to avoid using math backend (which does a .item() call). + # Arbitrarily choose flash attention. + backend_choice = SDPBackend.FLASH_ATTENTION + if backend_choice == SDPBackend.FLASH_ATTENTION: og_size = query.size(-1) query_padded = _pad_last_dim(query, 8, False) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5567db5919c18c..8cb0a9ac2dea8e 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,21 +1,14 @@ # mypy: allow-untyped-defs """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib -from typing import List, Union +from typing import Iterable, List, Union from warnings import warn +import torch.backends.cuda from torch._C import _SDPBackend as SDPBackend from torch.backends.cuda import ( can_use_efficient_attention, can_use_flash_attention, - cudnn_sdp_enabled, - enable_cudnn_sdp, - enable_flash_sdp, - enable_math_sdp, - enable_mem_efficient_sdp, - flash_sdp_enabled, - math_sdp_enabled, - mem_efficient_sdp_enabled, SDPAParams, ) @@ -67,6 +60,32 @@ def _raise_kernel_warnings(params: SDPAParams) -> None: can_use_flash_attention(params, True) +_backend_names = { + "cudnn": "CUDNN_ATTENTION", + "flash": "FLASH_ATTENTION", + "mem_efficient": "EFFICIENT_ATTENTION", + "math": "MATH", +} + + +def _backend_from_string(name: str): + return getattr(SDPBackend, name) + + +def _cur_sdpa_kernel_backends(): + backends: List[SDPBackend] = [] + for name, val in _backend_names.items(): + if getattr(torch.backends.cuda, f"{name}_sdp_enabled")(): + backends.append(getattr(SDPBackend, val)) + return backends + + +def _sdpa_kernel(backends: Iterable[SDPBackend]): + for name, val in _backend_names.items(): + enabled = getattr(SDPBackend, val) in backends + getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled) + + @contextlib.contextmanager def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): r""" @@ -102,26 +121,19 @@ def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): backends = [backends] backends = set(backends) - previous_cudnn: bool = cudnn_sdp_enabled() - previous_flash: bool = flash_sdp_enabled() - previous_mem_efficient: bool = mem_efficient_sdp_enabled() - previous_math: bool = math_sdp_enabled() + previous_backends = _cur_sdpa_kernel_backends() try: - enable_cudnn = SDPBackend.CUDNN_ATTENTION in backends - enable_flash = SDPBackend.FLASH_ATTENTION in backends - enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends - enable_math = SDPBackend.MATH in backends - - enable_cudnn_sdp(enable_cudnn) - enable_flash_sdp(enable_flash) - enable_mem_efficient_sdp(enable_mem_efficient) - enable_math_sdp(enable_math) + _sdpa_kernel(backends) yield {} finally: - enable_cudnn_sdp(previous_cudnn) - enable_flash_sdp(previous_flash) - enable_mem_efficient_sdp(previous_mem_efficient) - enable_math_sdp(previous_math) + _sdpa_kernel(previous_backends) + + +# variadic version of sdpa_kernel for dynamo to use while reconstructing +@contextlib.contextmanager +def _sdpa_kernel_variadic(*backends: SDPBackend): + with sdpa_kernel(list(backends)): + yield def _get_flash_version() -> str: diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 2fed60030a6898..da7acb957d96d7 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -289,7 +289,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): ) return cls._dispatch(*args, **kwargs) - def __repr__(self): + def __repr__(self): # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 5b93b1e018f6ae..2317dbac83bdda 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,6 +19,7 @@ ) from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( + _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input @@ -51,7 +52,6 @@ class _ModificationType(Enum): UNKNOWN = 3 -@torch._dynamo.assume_constant_result def _get_mod_type(fn: Callable) -> _ModificationType: """Get the type of modification function. This function inspects the number of positional arguments of the function to determine @@ -392,12 +392,59 @@ def __str__(self): return s def __getitem__(self, index) -> "BlockMask": - mapped_attributes = tree_map_only( - torch.Tensor, - lambda x: x[index], - self.as_tuple(flatten=False), + """ + Returns a new BlockMask instance by getting the mask for the given index position. + + Args: + index: Index to apply to all attributes. + + Example Usage: + .. code-block:: python + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda") + assert block_mask.kv_num_blocks.shape == (4,2,4) + assert block_mask.kv_indices.shape == (4,2,4,4) + + # Index on batch dimension + new_block_mask = block_mask[0] + assert new_block_mask.kv_num_blocks.shape == (2,4) + assert new_block_mask.kv_indices.shape == (2,4,4) + + # Index on batch and head dimension + new_block_mask = block_mask[0, 1] + assert new_block_mask.kv_num_blocks.shape == (4,) + assert new_block_mask.kv_indices.shape == (4,4) + + # slicing on batch and head dimension + new_block_mask = block_mask[0:2, 1:2] + assert new_block_mask.kv_num_blocks.shape == (2,1,4) + assert new_block_mask.kv_indices.shape == (2,1,4,4) + + # slicing on batch, head, and query dimension + new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)] + assert new_block_mask.kv_num_blocks.shape == (2,1,1) + assert new_block_mask.kv_indices.shape == (2,1,1,4) + """ + new_kv_num_blocks = self.kv_num_blocks[index] + new_kv_indices = self.kv_indices[index] + if self.full_kv_num_blocks is not None: + assert self.full_kv_indices is not None + new_full_kv_num_blocks = self.full_kv_num_blocks[index] + new_full_kv_indices = self.full_kv_indices[index] + else: + new_full_kv_num_blocks = None + new_full_kv_indices = None + return BlockMask.from_kv_blocks( + new_kv_num_blocks, + new_kv_indices, + new_full_kv_num_blocks, + new_full_kv_indices, + BLOCK_SIZE=self.BLOCK_SIZE, + mask_mod=None, ) - return BlockMask(*mapped_attributes) def __repr__(self): def shape_or_none(x: Optional[torch.Tensor]): @@ -824,7 +871,9 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask: ) -def _apply_kernel_options(query, key, value, kernel_options): +def _apply_kernel_options( + query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options +): kernel_options = {} if kernel_options is None else dict(kernel_options) kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) @@ -832,11 +881,13 @@ def _apply_kernel_options(query, key, value, kernel_options): # If foward kernel needs to return logsumexp is decided by this rule internally. assert "OUTPUT_LOGSUMEXP" not in kernel_options - any_inputs_require_grad = ( - query.requires_grad or key.requires_grad or value.requires_grad - ) - output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() - kernel_options.setdefault("OUTPUT_LOGSUMEXP", output_logsumexp) + kernel_options["OUTPUT_LOGSUMEXP"] = True + if not return_lse: + any_inputs_require_grad = ( + query.requires_grad or key.requires_grad or value.requires_grad + ) + output_logsumexp = any_inputs_require_grad and torch.is_grad_enabled() + kernel_options["OUTPUT_LOGSUMEXP"] = output_logsumexp return kernel_options @@ -953,10 +1004,17 @@ def score_mod( if scale is None: scale = 1.0 / math.sqrt(query.size(-1)) + if query.device != block_mask.kv_num_blocks.device: + raise RuntimeError( + f"Expect q/k/v and block_mask to be on the same device " + f"but got {query.device} and {block_mask.kv_num_blocks.device}." + ) + kernel_options = _apply_kernel_options( query, key, value, + return_lse, kernel_options, ) @@ -976,6 +1034,10 @@ def score_mod( if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") + from torch._dynamo.backends.debugging import ( + make_eager_backend_with_torch_function_mode, + ) + # Dynamo is expecting a callable with "__code__" attribute. # We cannot directly pass hop to it. So we wrap it in a dummy function. def _flex_attention_hop_wrapper(*args, **kwargs): @@ -984,18 +1046,25 @@ def _flex_attention_hop_wrapper(*args, **kwargs): with _set_compilation_env(): with torch._dynamo.utils.disable_cache_limit(): with _temp_remove_pre_dispatch_torch_function_mode(): - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend="eager", fullgraph=True - )( - query, - key, - value, - score_mod, - block_mask.as_tuple(), - scale, - kernel_options, - ) - if return_lse: - return out, lse * math.log(2) - else: - return out + with _temp_remove_metadata_torch_function_mode() as metadata_mode: + if metadata_mode: + backend = make_eager_backend_with_torch_function_mode( + metadata_mode + ) + else: + backend = "eager" + out, lse = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + )( + query, + key, + value, + score_mod, + block_mask.as_tuple(), + scale, + kernel_options, + ) + if return_lse: + return out, lse * math.log(2) + else: + return out diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 4bae0cfafb1799..9640ca1e76e290 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -5620,7 +5620,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) @@ -5631,7 +5631,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: - attn_bias += attn_mask + attn_bias = attn_mask + attn_bias if enable_gqa: key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) @@ -5745,7 +5745,7 @@ def forward(self, ...): >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") - >>> with torch.backends.cuda.sdp_kernel(enable_math=False): + >>> with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): >>> F.scaled_dot_product_attention(query,key,value) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 4889a4485c4901..02369cb4974b41 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -979,11 +979,11 @@ class MultiheadAttention(Module): Multi-Head Attention is defined as: .. math:: - \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O - where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. - ``nn.MultiHeadAttention`` will use the optimized implementations of + ``nn.MultiheadAttention`` will use the optimized implementations of ``scaled_dot_product_attention()`` when possible. In addition to support for the new ``scaled_dot_product_attention()`` diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 2a3b24b9b280bf..cfcc447625ea3b 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -18,13 +18,13 @@ class AdaptiveLogSoftmaxWithLoss(Module): - r"""Efficient softmax approximation. + """Efficient softmax approximation. As described in `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin, Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou `__. - +""" r""" Adaptive softmax is an approximate strategy for training models with large output spaces. It is most effective when the label distribution is highly imbalanced, for example in natural language modelling, where the word diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index ccb628dff6a313..170c806e47fefc 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -241,11 +241,13 @@ class Conv1d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also - known as the \uue0 trous algorithm. It is harder to describe, but this `link`_ + known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} Note: @@ -268,14 +270,14 @@ class Conv1d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes @@ -404,10 +406,13 @@ class Conv2d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" {groups_note} @@ -438,13 +443,13 @@ class Conv2d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes ) @@ -574,9 +579,12 @@ class Conv3d(_ConvNd): * :attr:`padding` controls the amount of padding applied to the input. It can be either a string {{'valid', 'same'}} or a tuple of ints giving the amount of implicit padding applied on both sides. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. +""" + r""" {groups_note} @@ -607,10 +615,10 @@ class Conv3d(_ConvNd): stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` """.format( **reproducibility_notes, **convolution_notes ) @@ -834,10 +842,12 @@ class ConvTranspose1d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} Note: @@ -996,10 +1006,12 @@ class ConvTranspose2d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` @@ -1184,10 +1196,12 @@ class ConvTranspose3d(_ConvTransposeNd): * :attr:`output_padding` controls the additional size added to one side of the output shape. See note below for details. - +""" + """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but the link `here`_ has a nice visualization of what :attr:`dilation` does. - +""" + r""" {groups_note} The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`output_padding` @@ -1453,14 +1467,14 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv1d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ @@ -1522,14 +1536,14 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv2d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ @@ -1592,14 +1606,14 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 - padding_mode (str, optional): ``'zeros'``, ``'reflect'``, - ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True`` + padding_mode (str, optional): ``'zeros'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'zeros'`` .. seealso:: :class:`torch.nn.Conv3d` and :class:`torch.nn.modules.lazy.LazyModuleMixin` """ diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 54d97d5a58bf10..58397caec32c8c 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -42,10 +42,10 @@ class Fold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. - +""" """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" r""" Args: output_size (int or tuple): the shape of the spatial dimensions of the output (i.e., ``output.sizes()[2:]``) @@ -194,10 +194,10 @@ class Unfold(Module): * :attr:`padding` controls the amount of implicit zero-paddings on both sides for :attr:`padding` number of points for each dimension before reshaping. - +""" """ * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm. It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does. - +""" r""" Args: kernel_size (int or tuple): the size of the sliding blocks dilation (int or tuple, optional): a parameter that controls the diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index a15850553f1f56..f4796e50e415a1 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1746,9 +1746,11 @@ def _call_impl(self, *args, **kwargs): or _global_forward_hooks or _global_forward_pre_hooks): return forward_call(*args, **kwargs) - try: - result = None - called_always_called_hooks = set() + result = None + called_always_called_hooks = set() + + def inner(): + nonlocal result, args, kwargs full_backward_hooks, non_full_backward_hooks = [], [] backward_pre_hooks = [] @@ -1826,6 +1828,20 @@ def _call_impl(self, *args, **kwargs): return result + from torch.compiler import is_compiling + + # This is technically not behavior equivalent when compiling, but it's + # incredibly unlikely we will ever support throwing an exception in NN + # module, and then catching it here, and then reraising it, and then + # catching it again, and expecting the resulting frame to be compiled. + # The reraise here just gunks up our exception handling for no good + # reason. Don't try to run the always called hooks in event of + # exception. + if is_compiling(): + return inner() + + try: + return inner() except Exception: # run always called hooks if they have not already been run # For now only forward hooks have the always_call option but perhaps diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 09a38e1119de6b..13fa6324833196 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -285,10 +285,10 @@ def pack_padded_sequence( ) -> PackedSequence: r"""Packs a Tensor containing padded sequences of variable length. - :attr:`input` can be of size ``T x B x *`` where ``T`` is the length of the - longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions - (including 0). If :attr:`batch_first` is ``False``, ``T x B x *`` :attr:`input` is expected, - ``B x T x *`` otherwise. + :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``) + or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length + of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions + (including 0). For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is ``True``, the sequences should be sorted by length in a decreasing order, i.e. diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 69994f315f77e6..dcb80a94c950d1 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -186,7 +186,7 @@ def _reparametrize_module( def functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, @@ -264,7 +264,7 @@ def functional_call( def _functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], - args: Union[Any, Tuple], + args: Optional[Union[Any, Tuple]] = None, kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, @@ -290,7 +290,9 @@ def _functional_call( ) if kwargs is None: kwargs = {} - if not isinstance(args, tuple): + if args is None: + args = () + elif not isinstance(args, tuple): args = (args,) with _reparametrize_module( module, parameters_and_buffers, tie_weights=tie_weights, strict=strict diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 0dbc42aea33395..73979c014bf54b 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -36,18 +36,13 @@ "select_model_mode_for_export", "register_custom_op_symbolic", "unregister_custom_op_symbolic", - "disable_log", - "enable_log", - # Errors - "CheckerError", # Backwards compatibility + # Base error + "OnnxExporterError", # Dynamo Exporter "DiagnosticOptions", "ExportOptions", "ONNXProgram", - "ONNXProgramSerializer", "ONNXRuntimeOptions", - "InvalidExportOptionsError", - "OnnxExporterError", "OnnxRegistry", "dynamo_export", "enable_fake_mode", @@ -55,7 +50,7 @@ "is_onnxrt_backend_supported", ] -from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING import torch from torch import _C @@ -70,7 +65,7 @@ OrtExecutionProvider as _OrtExecutionProvider, ) from ._type_utils import JitScalarType -from .errors import CheckerError # Backwards compatibility +from .errors import OnnxExporterError from .utils import ( _optimize_graph, _run_symbolic_function, @@ -109,12 +104,8 @@ DiagnosticOptions, ExportOptions, ONNXProgram, - ONNXProgramSerializer, ONNXRuntimeOptions, - InvalidExportOptionsError, - OnnxExporterError, OnnxRegistry, - dynamo_export, enable_fake_mode, ) @@ -123,22 +114,19 @@ import os # Set namespace for exposed private names +DiagnosticOptions.__module__ = "torch.onnx" +ExportOptions.__module__ = "torch.onnx" ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" -ExportOptions.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" -ONNXProgramSerializer.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" -dynamo_export.__module__ = "torch.onnx" -InvalidExportOptionsError.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" -enable_fake_mode.__module__ = "torch.onnx" OnnxRegistry.__module__ = "torch.onnx" -DiagnosticOptions.__module__ = "torch.onnx" -is_onnxrt_backend_supported.__module__ = "torch.onnx" -_OrtExecutionProvider.__module__ = "torch.onnx" -_OrtBackendOptions.__module__ = "torch.onnx" _OrtBackend.__module__ = "torch.onnx" +_OrtBackendOptions.__module__ = "torch.onnx" +_OrtExecutionProvider.__module__ = "torch.onnx" +enable_fake_mode.__module__ = "torch.onnx" +is_onnxrt_backend_supported.__module__ = "torch.onnx" producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION @@ -294,7 +282,9 @@ def forward(self, x): This is required for models with large weights that exceed the ONNX file size limit (2GB). When False, the weights are saved in the ONNX file with the model architecture. dynamic_shapes: A dictionary of dynamic shapes for the model inputs. Refer to - :func:`torch.export.export` for more details. + :func:`torch.export.export` for more details. This is only used (and preferred) when dynamo is True. + Only one parameter `dynamic_axes` or `dynamic_shapes` should be set + at the same time. report: Whether to generate a markdown report for the export process. verify: Whether to verify the exported model using ONNX Runtime. profile: Whether to profile the export process. @@ -374,6 +364,12 @@ def forward(self, x): else: from torch.onnx.utils import export + if dynamic_shapes: + raise ValueError( + "The exporter only supports dynamic shapes " + "through parameter dynamic_axes when dynamo=False." + ) + export( model, args, @@ -396,35 +392,126 @@ def forward(self, x): return None -# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram | Any: + """Export a torch.nn.Module to an ONNX graph. -# Returns True iff ONNX logging is turned on. -is_onnx_log_enabled = _C._jit_is_onnx_log_enabled + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. + **Example 1 - Simplest export** + :: -def enable_log() -> None: - r"""Enables ONNX logging.""" - _C._jit_set_onnx_log_enabled(True) + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + def forward(self, x, bias=None): + out = self.linear(x) + out = out + bias + return out -def disable_log() -> None: - r"""Disables ONNX logging.""" - _C._jit_set_onnx_log_enabled(False) + model = MyModel() + kwargs = {"bias": 3.0} + args = (torch.randn(2, 2, 2),) + onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( + "my_simple_model.onnx" + ) + + **Example 2 - Exporting with dynamic shapes** + :: + + # The previous model can be exported with dynamic shapes + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + onnx_program.save("my_dynamic_model.onnx") + """ -"""Sets output stream for ONNX logging. + # NOTE: The new exporter is experimental and is not enabled by default. + import warnings -Args: - stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported - as ``stream_name``. -""" -set_log_stream = _C._jit_set_onnx_log_output_stream + from torch.onnx import _flags + from torch.onnx._internal import exporter + from torch.utils import _pytree + if isinstance(model, torch.export.ExportedProgram): + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + elif _flags.USE_EXPERIMENTAL_LOGIC: + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=FutureWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic + def _to_dynamic_shapes_mapper(): + arg_order = 0 + + def _to_dynamic_shape(x): + nonlocal arg_order + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim( + f"arg_{arg_order}_dim_{i}" + ) + arg_order += 1 + return dynamic_shape + else: + return None + + return _to_dynamic_shape + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shapes_mapper(), + model_args, + ) + else: + dynamic_shapes = None -"""A simple logging facility for ONNX exporter. + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + else: + from torch.onnx._internal._exporter_legacy import dynamo_export -Args: - args: Arguments are converted to string, concatenated together with a newline - character appended to the end, and flushed to output stream. -""" -log = _C._jit_onnx_log + return dynamo_export( + model, *model_args, export_options=export_options, **model_kwargs + ) diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index df1ae10c9c26e2..0222da61cfef48 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -1,7 +1,19 @@ # mypy: allow-untyped-defs -from __future__ import ( # for onnx.ModelProto (ONNXProgram) and onnxruntime (ONNXRuntimeOptions) - annotations, -) +from __future__ import annotations + + +__all__ = [ + "DiagnosticOptions", + "ExportOptions", + "ONNXProgram", + "ONNXRuntimeOptions", + "InvalidExportOptionsError", + "OnnxRegistry", + "UnsatisfiedDependencyError", + "dynamo_export", + "enable_fake_mode", +] + import abc import contextlib @@ -11,23 +23,13 @@ import tempfile import warnings from collections import defaultdict -from typing import ( - Any, - Callable, - Final, - Mapping, - Protocol, - runtime_checkable, - Sequence, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar from typing_extensions import Self import torch import torch._ops -import torch.export as torch_export import torch.utils._pytree as pytree +from torch.onnx import errors from torch.onnx._internal import io_adapter from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.fx import ( @@ -46,11 +48,8 @@ import onnx - import onnxruntime # type: ignore[import] - import onnxscript # type: ignore[import] - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] - registration as torchlib_registry, - ) + import onnxruntime + import onnxscript from torch._subclasses import fake_tensor from torch.onnx._internal.fx import diagnostics @@ -110,10 +109,6 @@ def __init__(self) -> None: self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( defaultdict(list) ) - # FIXME: Avoid importing onnxscript into torch - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] # noqa: F401 - registration, - ) # opset_version is unused for now, since torchlib only supports opset18. # TODO: get opset version from torchlib @@ -123,8 +118,7 @@ def __init__(self) -> None: "different opset version, please register them with register_custom_op." ) - # Initialize registry from torchlib - self._initiate_registry_from_torchlib(registration.default_registry) + self._initiate_registry_from_torchlib() @property def opset_version(self) -> int: @@ -134,33 +128,25 @@ def opset_version(self) -> int: return self._opset_version - def _initiate_registry_from_torchlib( - self, torchlib_registry: torchlib_registry.Registry - ): + def _initiate_registry_from_torchlib(self) -> None: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ - for aten_name, aten_overloads_func in torchlib_registry.items(): - internal_name_instance = registration.OpName.from_qualified_name(aten_name) - for overload_func in aten_overloads_func.overloads: - symbolic_function = registration.ONNXFunction( - onnx_function=overload_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=False, - ) - self._register(internal_name_instance, symbolic_function) - - for complex_func in aten_overloads_func.complex: - symbolic_function = registration.ONNXFunction( - onnx_function=complex_func, - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=True, - ) - self._register(internal_name_instance, symbolic_function) + import onnxscript._framework_apis.torch_2_5 as onnxscript_apis + + for meta in onnxscript_apis.get_torchlib_ops(): + internal_name_instance = registration.OpName.from_qualified_name( + meta.qualified_name + ) + symbolic_function = registration.ONNXFunction( + onnx_function=meta.function, # type: ignore[arg-type] + op_full_name=internal_name_instance.qualified_name(), + is_custom=False, + is_complex=meta.is_complex, + ) + self._register(internal_name_instance, symbolic_function) def _register( self, @@ -263,7 +249,6 @@ class ExportOptions: When ``None``, the exporter determines the most compatible setting. When ``True``, all input shapes are considered dynamic. When ``False``, all input shapes are considered static. - op_level_debug: Whether to export the model with op-level debug information diagnostic_options: The diagnostic options for the exporter. fake_context: The fake context used for symbolic tracing. onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. @@ -277,9 +262,6 @@ class ExportOptions: - ``False``: all input shapes are considered static. """ - op_level_debug: bool | None = None - """When True export the model with op-level debug running ops through ONNX Runtime.""" - diagnostic_options: DiagnosticOptions """The diagnostic options for the exporter.""" @@ -293,13 +275,11 @@ def __init__( self, *, dynamic_shapes: bool | None = None, - op_level_debug: bool | None = None, fake_context: ONNXFakeContext | None = None, onnx_registry: OnnxRegistry | None = None, diagnostic_options: DiagnosticOptions | None = None, ): self.dynamic_shapes = dynamic_shapes - self.op_level_debug = op_level_debug self.fake_context = fake_context self.onnx_registry = onnx_registry self.diagnostic_options = diagnostic_options or DiagnosticOptions() @@ -313,7 +293,6 @@ class ResolvedExportOptions(ExportOptions): # Public attributes MUST be redefined below without ``Optional[]`` from ``ExportOptions`` dynamic_shapes: bool - op_level_debug: bool diagnostic_options: DiagnosticOptions fake_context: ONNXFakeContext onnx_registry: OnnxRegistry @@ -337,28 +316,17 @@ class ResolvedExportOptions(ExportOptions): def __init__( self, options: ExportOptions | ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined] + model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined] ): from torch.onnx._internal.fx import ( # TODO: Prevent circular dep diagnostics, dynamo_graph_extractor, - torch_export_graph_extractor, ) if isinstance(options, ResolvedExportOptions): self.dynamic_shapes = options.dynamic_shapes - self.op_level_debug = options.op_level_debug self.diagnostic_options = options.diagnostic_options self.fake_context = options.fake_context - # private - if isinstance(model, torch_export.ExportedProgram) and not isinstance( - options.fx_tracer, torch_export_graph_extractor.TorchExport - ): - message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer" - e = InvalidExportOptionsError(message) - raise InvalidExportOptionsError( - ONNXProgram._from_failure(e, options.diagnostic_context), message - ) self.fx_tracer = options.fx_tracer self.onnx_registry = options.onnx_registry self.onnxfunction_dispatcher = options.onnxfunction_dispatcher @@ -379,10 +347,8 @@ def resolve(value: T | None, fallback: T | Callable[[], T]) -> T: self.diagnostic_options = resolve( options.diagnostic_options, DiagnosticOptions() ) - if isinstance(model, torch_export.ExportedProgram): - self.fx_tracer = torch_export_graph_extractor.TorchExport() - else: - self.fx_tracer = dynamo_graph_extractor.DynamoExport() + + self.fx_tracer = dynamo_graph_extractor.DynamoExport() self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type] self.diagnostic_context = diagnostics.DiagnosticContext( @@ -400,7 +366,6 @@ def resolve(value: T | None, fallback: T | Callable[[], T]) -> T: from torch.onnx._internal.fx import onnxfunction_dispatcher - self.op_level_debug = resolve(options.op_level_debug, False) self.onnxfunction_dispatcher = ( onnxfunction_dispatcher.OnnxFunctionDispatcher( self.onnx_registry, @@ -430,8 +395,7 @@ def enable_fake_mode(): are too large to fit into memory. Returns: - A :class:`ONNXFakeContext` object that must be passed to :func:`dynamo_export` - through the :attr:`ExportOptions.fake_context` argument. + A :class:`ONNXFakeContext` object. Example:: @@ -449,15 +413,16 @@ def enable_fake_mode(): ... my_nn_module = MyModel() ... arg1 = torch.randn(2, 2, 2) # positional input 1 >>> export_options = torch.onnx.ExportOptions(fake_context=fake_context) - >>> onnx_program = torch.onnx.dynamo_export( - ... my_nn_module, - ... arg1, - ... export_options=export_options - ... ) + >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True) + >>> onnx_program.apply_weights(MyModel().state_dict()) >>> # Saving model WITHOUT initializers - >>> onnx_program.save("my_model_without_initializers.onnx") + >>> onnx_program.save( + ... "my_model_without_initializers.onnx", + ... include_initializers=False, + ... keep_initializers_as_inputs=True, + ... ) >>> # Saving model WITH initializers - >>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict()) + >>> onnx_program.save("my_model_with_initializers.onnx") .. warning:: This API is experimental and is *NOT* backward-compatible. @@ -487,97 +452,6 @@ def enable_fake_mode(): ) # type: ignore[assignment] -@runtime_checkable -class ONNXProgramSerializer(Protocol): - """Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf). - Note that this is an advanced usage scenario.""" - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - """Protocol method that must be implemented for serialization. - - Args: - onnx_program: Represents the in-memory exported ONNX model - destination: A binary IO stream or pre-allocated buffer into which - the serialized model should be written. - - Example: - - A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf - format to ``destination``: - - :: - - # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import io - >>> import torch - >>> import torch.onnx - >>> class MyModel(torch.nn.Module): # Dummy model - ... def __init__(self) -> None: - ... super().__init__() - ... self.linear = torch.nn.Linear(2, 2) - ... def forward(self, x): - ... out = self.linear(x) - ... return out - >>> class ProtobufONNXProgramSerializer: - ... def serialize( - ... self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase - ... ) -> None: - ... destination.write(onnx_program.model_proto.SerializeToString()) - >>> model = MyModel() - >>> arg1 = torch.randn(2, 2, 2) # positional input 1 - >>> torch.onnx.dynamo_export(model, arg1).save( - ... destination="exported_model.onnx", - ... serializer=ProtobufONNXProgramSerializer(), - ... ) - """ - ... - - -class ProtobufONNXProgramSerializer: - """Serializes ONNX graph as Protobuf.""" - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - import onnx - - if not isinstance(onnx_program.model_proto, onnx.ModelProto): # type: ignore[attr-defined] - raise ValueError("onnx_program.ModelProto is not an onnx.ModelProto") - destination.write(onnx_program.model_proto.SerializeToString()) - - -class LargeProtobufONNXProgramSerializer: - """Serializes ONNX graph as Protobuf. - - Fallback to serializing as Protobuf with external data for models larger than 2GB. - """ - - _destination_path: Final[str] # type: ignore[misc] - - def __init__(self, destination_path: str): - self._destination_path = destination_path - - def serialize( - self, onnx_program: ONNXProgram, destination: io.BufferedIOBase - ) -> None: - """`destination` is ignored. The model is saved to `self._destination_path` instead.""" - import onnx - - if onnx_program.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: - onnx.save_model(onnx_program.model_proto, self._destination_path) # type: ignore[attr-defined] - else: - # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB - # Fallback to serializing the model with external data. - onnx.save_model( # type: ignore[attr-defined] - onnx_program.model_proto, - self._destination_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - ) - - class ONNXRuntimeOptions: """Options to influence the execution of the ONNX model through ONNX Runtime. @@ -618,7 +492,6 @@ class ONNXProgram: diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata. fake_context: The fake context used for symbolic tracing. export_exception: The exception that occurred during export, if any. - model_signature: The model signature for the exported ONNX graph. """ _model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc] @@ -627,9 +500,8 @@ class ONNXProgram: _diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc] _fake_context: Final[ONNXFakeContext | None] # type: ignore[misc] _export_exception: Final[Exception | None] # type: ignore[misc] - _model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc] _model_torch: Final[ # type: ignore[misc] - torch.nn.Module | Callable | torch_export.ExportedProgram | None + torch.nn.Module | Callable | None ] def __init__( @@ -641,28 +513,21 @@ def __init__( *, fake_context: ONNXFakeContext | None = None, export_exception: Exception | None = None, - model_signature: torch.export.ExportGraphSignature | None = None, - model_torch: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_torch: torch.nn.Module | Callable | None = None, ): self._model_proto = model_proto - self._model_signature = model_signature self._model_torch = model_torch self._input_adapter = input_adapter self._output_adapter = output_adapter self._diagnostic_context = diagnostic_context self._fake_context = fake_context self._export_exception = export_exception + self._state_dict: dict[str, torch.Tensor] = {} def __call__( self, *args: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, options: ONNXRuntimeOptions | None = None, **kwargs: Any, ) -> Any: @@ -697,10 +562,8 @@ def __call__( onnx_model = os.path.join(tmpdir_path, "model.onnx") if isinstance(model_with_state_dict, torch.nn.Module): model_state = model_with_state_dict.state_dict() - elif isinstance(model_with_state_dict, torch_export.ExportedProgram): - model_state = model_with_state_dict.state_dict else: - model_state = None + model_state = self._state_dict self.save( onnx_model, model_state=model_state, @@ -734,104 +597,6 @@ def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined] raise self._export_exception return self._model_proto - @property - def model_signature(self) -> torch.export.ExportGraphSignature | None: - """The model signature for the exported ONNX graph. - - This information is relevant because ONNX specification often differs from PyTorch's, resulting - in a ONNX graph with input and output schema different from the actual PyTorch model implementation. - By using the model signature, the users can understand the inputs and outputs differences - and properly execute the model in ONNX Runtime. - - NOTE: Model signature is only available when the ONNX graph was exported from a - :class:`torch.export.ExportedProgram` object. - - NOTE: Any transformation done to the model that changes the model signature must be accompanied - by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`. - - Example: - - The following model produces different sets of inputs and outputs. - The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight), - and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally - the last 2 inputs are user inputs (namely x and b). - The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output. - - >>> import pprint - >>> class CustomModule(torch.nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) - ... self.register_buffer("my_buffer1", torch.tensor(3.0)) - ... self.register_buffer("my_buffer2", torch.tensor(4.0)) - ... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False) - ... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False) - ... self.fc1 = torch.nn.Linear(9216, 128, bias=False) - ... self.fc2 = torch.nn.Linear(128, 10, bias=False) - ... - ... def forward(self, x, b): - ... tensor_x = self.conv1(x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.conv2(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2) - ... tensor_x = torch.flatten(tensor_x, 1) - ... tensor_x = self.fc1(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.fc2(tensor_x) - ... output = torch.nn.functional.log_softmax(tensor_x, dim=1) - ... ( - ... self.my_buffer2.add_(1.0) + self.my_buffer1 - ... ) # Mutate buffer through in-place addition - ... return output - >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3)) - >>> exported_program = torch.export.export( - ... CustomModule(), args=inputs - ... ).run_decompositions({}) - >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs) - >>> pprint.pprint(onnx_program.model_signature) - ExportGraphSignature(input_specs=[InputSpec(kind=, - arg=TensorArgument(name='p_conv1_weight'), - target='conv1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_conv2_weight'), - target='conv2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc1_weight'), - target='fc1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc2_weight'), - target='fc2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer2'), - target='my_buffer2', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer1'), - target='my_buffer1', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b'), - target=None, - persistent=None)], - output_specs=[OutputSpec(kind=, - arg=TensorArgument(name='add'), - target='my_buffer2'), - OutputSpec(kind=, - arg=TensorArgument(name='_log_softmax'), - target=None)]) - """ - - return self._model_signature - @property def diagnostic_context(self) -> diagnostics.DiagnosticContext: """The diagnostic context associated with the export.""" @@ -847,10 +612,7 @@ def fake_context(self) -> ONNXFakeContext | None: def adapt_torch_inputs_to_onnx( self, *model_args, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, **model_kwargs, ) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]: """Converts the PyTorch model inputs to exported ONNX model inputs format. @@ -920,10 +682,7 @@ def adapt_torch_inputs_to_onnx( def adapt_torch_outputs_to_onnx( self, model_outputs: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, ) -> Sequence[torch.Tensor | int | float | bool]: """Converts the PyTorch model outputs to exported ONNX model outputs format. @@ -978,13 +737,19 @@ def adapt_torch_outputs_to_onnx( ), "model_with_state_dict must be specified." return self._output_adapter.apply(model_outputs, model=model_with_state_dict) # type: ignore[return-value] + def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Apply the weights from the specified state dict to the ONNX model. + Args: + state_dict: The state dict containing the weights to apply to the ONNX model. + """ + self._state_dict = state_dict + def save( self, destination: str | io.BufferedIOBase, *, include_initializers: bool = True, model_state: dict[str, Any] | str | None = None, - serializer: ONNXProgramSerializer | None = None, ) -> None: """Saves the in-memory ONNX model to ``destination`` using specified ``serializer``. @@ -1000,17 +765,15 @@ def save( It can be either a string with the path to a checkpoint or a dictionary with the actual model state. The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`. Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph. - serializer: The serializer to use. If not specified, the model will be serialized as Protobuf. """ + import onnx assert ( include_initializers is True or model_state is None ), "Cannot specify both `include_initializers=False` and `model_state`." - if serializer is None: - if isinstance(destination, str): - serializer = LargeProtobufONNXProgramSerializer(destination) - else: - serializer = ProtobufONNXProgramSerializer() + + if self._state_dict and model_state is None: + model_state = self._state_dict # Add initializers when symbolic tracing is enabled _model_state_files: list[str | io.BytesIO | dict[str, Any]] = [] @@ -1056,10 +819,20 @@ def save( else: if isinstance(destination, str): with open(destination, "wb") as f: - serializer.serialize(self, f) + if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT: + onnx.save_model(self.model_proto, destination) # type: ignore[attr-defined] + else: + # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB + # Fallback to serializing the model with external data. + onnx.save_model( # type: ignore[attr-defined] + self.model_proto, + destination, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) else: try: - serializer.serialize(self, destination) + destination.write(self.model_proto.SerializeToString()) except ValueError as exc: raise ValueError( "'destination' should be provided as a path-like string when saving a model larger than 2GB. " @@ -1168,7 +941,7 @@ class Exporter: def __init__( self, options: ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram, + model: torch.nn.Module | Callable, model_args: Sequence[Any], model_kwargs: Mapping[str, Any], ): @@ -1213,7 +986,6 @@ def export(self) -> ONNXProgram: onnxscript_graph = fx_interpreter.run( fx_graph_module=graph_module, onnxfunction_dispatcher=self.options.onnxfunction_dispatcher, - op_level_debug=self.options.op_level_debug, ) # NOTE: Filter out the initializers with fake tensors when it's fake_mode exporting. @@ -1257,9 +1029,6 @@ def export(self) -> ONNXProgram: self.options.fx_tracer.output_adapter, self.options.diagnostic_context, fake_context=self.options.fake_context, - model_signature=getattr( - self.model, "graph_signature", None - ), # Available for isinstance(self.model, ExportedProgram) only model_torch=self.model, ) @@ -1312,28 +1081,6 @@ def __init__(self, package_name: str, message: str): self.package_name = package_name -class OnnxExporterError(RuntimeError): - """Raised when an ONNX exporter error occurs. - - This exception is thrown when there's an error during the ONNX export process. - It encapsulates the :class:`ONNXProgram` object generated until the failure, allowing - access to the partial export results and associated metadata. - """ - - onnx_program: Final[ONNXProgram] # type: ignore[misc] - - def __init__(self, onnx_program: ONNXProgram, message: str): - """ - Initializes the OnnxExporterError with the given ONNX program and message. - - Args: - onnx_program (ONNXProgram): The partial results of the ONNX export. - message (str): The error message to be displayed. - """ - super().__init__(message) - self.onnx_program = onnx_program - - class InvalidExportOptionsError(RuntimeError): """Raised when user specified an invalid value for the :class:`ExportOptions`.""" @@ -1380,12 +1127,12 @@ def missing_opset(package_name: str): def dynamo_export( - model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined] + model: torch.nn.Module | Callable, /, *model_args, export_options: ExportOptions | None = None, **model_kwargs, -) -> ONNXProgram: +) -> ONNXProgram | Any: """Export a torch.nn.Module to an ONNX graph. Args: @@ -1483,10 +1230,7 @@ def forward(self, x, bias=None): "or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). " f"Please report a bug on PyTorch Github: {_PYTORCH_GITHUB_ISSUES_URL}" ) - raise OnnxExporterError( - ONNXProgram._from_failure(e, resolved_export_options.diagnostic_context), - message, - ) from e + raise errors.OnnxExporterError(message) from e def common_pre_export_passes( @@ -1564,18 +1308,3 @@ def common_pre_export_passes( ) return module - - -__all__ = [ - "DiagnosticOptions", - "ExportOptions", - "ONNXProgram", - "ONNXProgramSerializer", - "ONNXRuntimeOptions", - "InvalidExportOptionsError", - "OnnxExporterError", - "OnnxRegistry", - "UnsatisfiedDependencyError", - "dynamo_export", - "enable_fake_mode", -] diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 382bf7d78cd4cf..b12e53ef292624 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -30,6 +30,7 @@ def __getattr__(self, attr): if TYPE_CHECKING: import onnx import onnxscript + import onnxscript._framework_apis.torch_2_5 as onnxscript_apis onnxscript_ir = onnxscript.ir @@ -37,3 +38,4 @@ def __getattr__(self, attr): onnx = _LazyModule("onnx") onnxscript = _LazyModule("onnxscript") onnxscript_ir = _LazyModule("onnxscript.ir") + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5") diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 43a6519482663f..2eb5adab3a7d9d 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -10,8 +10,6 @@ from collections import defaultdict from typing import TYPE_CHECKING -import onnxscript - import torch import torch._export.serde.schema from torch.export import graph_signature @@ -203,13 +201,7 @@ def analyze( model_info.outputs = outputs if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops # noqa: F401 - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() # Try to find ops for every node in the graph for node in exported_program.graph.nodes: diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 1b6057c6f2d85d..7729aa000b281e 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -13,14 +13,14 @@ import copy import inspect import logging -from typing import Any, Mapping, Sequence, TYPE_CHECKING, Union +from typing import Any, Iterable, Mapping, Sequence, TYPE_CHECKING, Union import onnxscript from onnxscript import evaluator, ir from onnxscript.ir import convenience as ir_convenience import torch -from torch.onnx._internal.exporter import _schemas, _tensors, errors +from torch.onnx._internal.exporter import _errors, _schemas, _tensors if TYPE_CHECKING: @@ -29,12 +29,13 @@ logger = logging.getLogger(__name__) -# TODO(justinchuby): Update ValidAttributeType to ir_convenience.SupportedAttrTypes ValidAttributeType = Union[ ir.TensorProtocol, int, float, bool, str, Sequence[int], Sequence[float], None ] -AllowedArgType = Union[ir.Value, Sequence[ir.Value], ValidAttributeType] +AllowedArgType = Union[ + ir.Value, Sequence[Union[ir.Value, ValidAttributeType]], ValidAttributeType +] # Logic for adapting inputs from general Python or PyTorch inputs to ONNX ir.Value @@ -181,13 +182,103 @@ def _resolve_parameter_dtypes( return type_binding -def _process_python_constants_and_sequences( +def _determine_input_dtype( + param: _schemas.Parameter, + arg: AllowedArgType, + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], +) -> ir.DataType: + """Determine the dtype of the input that is a mix of Python constants and ir.Value.""" + if param.type_constraint in type_binding: + # A known dtype is available because it was resolved + return type_binding[param.type_constraint].dtype + if len(param.type_constraint.allowed_types) == 1: + # Only one type is allowed by the type constraint + return next(iter(param.type_constraint.allowed_types)).dtype + + # No dtype information available. Infer from the Python constant or (in the Sequence case) + # from a mix of Python constants and ir.Value + if isinstance(arg, bool): + return ir.DataType.BOOL + if isinstance(arg, float): + return ir.DataType.FLOAT + if isinstance(arg, int): + return ir.DataType.INT64 + if isinstance(arg, str): + return ir.DataType.STRING + if isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + return arg.dtype + if arg is None: + return ir.DataType.UNDEFINED + + # Handle sequences + if isinstance(arg, (tuple, list)): + if len(arg) == 0: + # Special case: Treat empty sequence as INT64 as they are typically used for shape + return ir.DataType.INT64 + + # Try to obtain the dtype from one of the values + for val in arg: + if isinstance(val, ir.Value) and val.dtype is not None: + return val.dtype + + if any(isinstance(val, float) for val in arg): + # If any float is present, the dtype is float + return ir.DataType.FLOAT + elif any(isinstance(val, int) for val in arg): + # Otherwise if any int is present, the dtype is int + return ir.DataType.INT64 + + raise ValueError( + f"Could not determine the dtype for the input '{param.name}'. " + f"param={param}, arg={arg}, param_type_constraint={param.type_constraint}, " + f"type_binding={type_binding}" + ) + + +def _allowed_types_are_sequence_types(allowed_types: Iterable[ir.TypeProtocol]) -> bool: + """Check if all allowed types are Sequence types.""" + return all(isinstance(t, ir.SequenceType) for t in allowed_types) + + +def _get_or_create_constant( + constant_farm: dict[ + tuple[ + bool | int | float | str | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + arg: bool + | int + | float + | str + | tuple[int] + | tuple[float] + | tuple[bool] + | list[int] + | list[float] + | list[bool], + dtype: ir.DataType, + opset: onnxscript.values.Opset, +) -> ir.Value: + if isinstance(arg, list): + # Make the arg hashable + arg = tuple(arg) # type: ignore[assignment] + constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] + if constant_value is None: + constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] + return constant_value + + +def _process_python_constants( signature: _schemas.OpSignature, named_inputs: dict[str, AllowedArgType], type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], constant_farm: dict[ tuple[ - bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + bool | int | float | str | tuple[int] | tuple[float], ir.DataType, ], ir.Value, @@ -206,7 +297,7 @@ def _process_python_constants_and_sequences( opset: The Opset to use for creating Constant nodes. Returns: - None + A mapping of parameter names to Python constants converted to constant Nodes. """ # 3. Convert Python constants to Constant nodes based on the dtype information; # construct sequences @@ -225,80 +316,128 @@ def _process_python_constants_and_sequences( if isinstance(arg, ir.Value): # TODO(justinchuby): Cast the ir.Value here if needed continue + if ( isinstance(arg, Sequence) and len(arg) > 0 - and all(isinstance(val, ir.Value) for val in arg) + and any(isinstance(val, ir.Value) for val in arg) ): # Skip the sequence of ir.Value. This is a variadic input or a Sequence input - # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants - # like `Max(0, ir.Value())` - # We need to convert the Python constants to Constant nodes - # NOTE: Important to check that arg is not empty because we need to treat it as list[int] or list[float] + # It will be handled by _process_python_sequences + continue + if param.variadic: + # Handled by _process_python_sequences + continue + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Handled by _process_python_sequences continue - # if param.variadic: - # # FXIME: Handle variadic inputs and sequence inputs differently - # raise NotImplementedError - # TODO: Find a way to recursively build constants. Maybe extract the logic out. - # FIXME: I am here - - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" - if param.type_constraint in type_binding: - # A known dtype is available - dtype = type_binding[param.type_constraint].dtype - elif len(param.type_constraint.allowed_types) == 1: - # Only one type is allowed - dtype = next(iter(param.type_constraint.allowed_types)).dtype - else: - # No dtype information available. Infer from the Python constant - if isinstance(arg, bool): - dtype = ir.DataType.BOOL - elif isinstance(arg, float): - dtype = ir.DataType.FLOAT - elif isinstance(arg, int): - dtype = ir.DataType.INT64 - elif isinstance(arg, str): - dtype = ir.DataType.STRING - elif isinstance(arg, (tuple, list)) and all( - isinstance(val, int) for val in arg - ): - dtype = ir.DataType.INT64 - elif isinstance(arg, (tuple, list)) and any( - isinstance(val, float) for val in arg - ): - # NOTE: if any float is present, the dtype is float - dtype = ir.DataType.FLOAT - elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - dtype = arg.dtype - elif arg is None: - dtype = ir.DataType.UNDEFINED - else: - raise TypeError( - f"Constant input '{arg}' of type '{type(arg)}' is not supported" - ) + dtype = _determine_input_dtype(param, arg, type_binding) if arg is None: constant_value = None - elif not isinstance(arg, (ir.Tensor, ir.TensorProtocol)): - # Deduplicate the constants - if isinstance(arg, (tuple, list)): - # Make the arg hashable - arg = tuple(arg) # noqa: PLW2901 - constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] - if constant_value is None: - constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] - constant_value = opset.Constant(value=constant_tensor) - constant_farm[(arg, dtype)] = constant_value # type: ignore[arg-type,index] - else: + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): constant_value = opset.Constant(value=arg) + else: + # Deduplicate the constants + constant_value = _get_or_create_constant(constant_farm, arg, dtype, opset) # type: ignore[arg-type] named_inputs[param.name] = constant_value return named_inputs # type: ignore[return-value] +def _process_python_sequences( + signature: _schemas.OpSignature, + named_inputs: dict[str, AllowedArgType], + type_binding: Mapping[_schemas.TypeConstraintParam, ir.TypeProtocol], + constant_farm: dict[ + tuple[ + bool | int | float | str | ir.TensorProtocol | tuple[int] | tuple[float], + ir.DataType, + ], + ir.Value, + ], + opset: onnxscript.values.Opset, +): + """Handle three types of sequences. + + 1. Variadic inputs + 2. Sequence input of ir.Value, + 3. Sequence of Python constants that contains ir.Value + """ + for name, arg in named_inputs.items(): + param = signature.params_map[name] + assert isinstance( + param, _schemas.Parameter + ), f"Expected Parameter, got {type(param)}" + + if not isinstance(arg, (tuple, list)): + continue + + if len(arg) == 0: + # Skip empty sequences + continue + + # 1. Sequence input of ir.Value + if _allowed_types_are_sequence_types(param.type_constraint.allowed_types): + # Turn the list into a Sequence node + # Constant op creation will be handled by the variadic case below when calling + # the SequenceConstruct op. + named_inputs[name] = opset.SequenceConstruct(*arg) + continue + + # 2. Variadic inputs + # NOTE: Variadic operators like Max can be called with mixed ir.Value and Python constants + # like `Max(0, ir.Value())` + # We need to convert the Python constants to Constant nodes + if param.variadic: + if all(isinstance(val, ir.Value) for val in arg): + # Skip the variadic input if all values are ir.Value + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(val) + else: + constant_tensor = ir.tensor(value=val, dtype=dtype) # type: ignore[arg-type] + constant_value = opset.Constant(value=constant_tensor) + new_args.append(constant_value) + named_inputs[name] = new_args + continue + else: + # 3. Concat the list as a single input + # E.g. [Value, 42] should be converted to op.Concat(Value, Constant(42)) + # when the expected input type is INT64 + # We assume this only happens for 1D cases + if all(isinstance(val, ir.Value) for val in arg): + named_inputs[name] = opset.Concat(*arg) + continue + + dtype = _determine_input_dtype(param, arg, type_binding) + new_args = [] + for val in arg: + if isinstance(val, ir.Value): + new_args.append(val) + elif val is None: + # Skip None values + continue + elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + new_args.append(opset.Constant(value=val)) + else: + # Turn the Python constant into 1D tensor for the constant + assert isinstance( + val, (bool, int, float) + ), f"Expected int or float, got {type(val)}" + new_args.append( + _get_or_create_constant(constant_farm, [arg], dtype, opset) # type: ignore[arg-type] + ) + named_inputs[name] = opset.Concat(*new_args) + continue + return named_inputs + + def _construct_node( signature: _schemas.OpSignature, named_inputs: Mapping[str, ir.Value | None], @@ -326,6 +465,12 @@ def _construct_node( else: inputs.append(value) + # If final inputs are None, strip them from the node inputs + for input in reversed(inputs): + if input is not None: + break + inputs.pop() + # Construct and filter out None attributes attributes = [ attr @@ -368,11 +513,19 @@ def _call_op( """ type_binding = _resolve_parameter_dtypes(op_signature, named_inputs) try: - converted_named_inputs = _process_python_constants_and_sequences( + converted_named_inputs = _process_python_constants( op_signature, named_inputs, type_binding, self.constant_farm, self.opset ) + converted_named_inputs = _process_python_sequences( + op_signature, + converted_named_inputs, # type: ignore[arg-type] + type_binding, + self.constant_farm, + self.opset, + ) + except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error processing Python constants for operator '{op_signature.domain}::{op_signature.name}'. " f"named_inputs={named_inputs}, named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." ) from e @@ -384,7 +537,7 @@ def _call_op( ) ) except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error constructing node for operator '{op_signature.domain}::{op_signature.name}'. " f"named_inputs={named_inputs}, converted_named_inputs={converted_named_inputs}, " f"named_attrs={named_attrs}, opset={self.opset}, op_signature={op_signature}." @@ -428,7 +581,7 @@ def eval( return outputs[0] return outputs except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error calling operator '{schema.name}' with args {args} and kwargs {kwargs}." ) from e @@ -494,6 +647,11 @@ def eval_function( # type: ignore[override] # call because it will filter out the unexpected kwargs for us. if function.traceable: # Trace the function call instead of adding the function as a node + # Turn the ir.Attr objects into Python constants first + named_attrs = { + name: attr.value if isinstance(attr, ir.Attr) else attr + for name, attr in named_attrs.items() + } return function.function(**named_inputs, **named_attrs) outputs = self._call_op(op_signature, named_inputs, named_attrs) @@ -508,7 +666,7 @@ def eval_function( # type: ignore[override] _, lineno = inspect.getsourcelines(function.function) except Exception: source_file = lineno = None - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error calling function '{function.name}' with args {args} and kwargs {kwargs}." + f" The function is defined at '{source_file}:{lineno}'." if source_file diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index dc511491d6b41b..4cec92854ea858 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -120,9 +120,22 @@ class TorchExportStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) @@ -148,9 +161,22 @@ class TorchExportNonStrictStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 642f768d7285cb..3fddef36b8b428 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -8,10 +8,8 @@ import logging from typing import Any, Mapping, Sequence, TYPE_CHECKING -import onnx - import torch -import torch.export +from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import _core, _onnx_program @@ -30,7 +28,9 @@ def _signature(model) -> inspect.Signature: def _from_dynamic_axes_to_dynamic_shapes( model, + *, dynamic_axes=None, + output_names: set[str], input_names: Sequence[str] | None = None, ) -> dict[str, Any] | None: """ @@ -70,10 +70,13 @@ def _from_dynamic_axes_to_dynamic_shapes( # for the exported program dynamic_shapes_to_exported_program = {} for input_name, axes in dynamic_axes.items(): - # input_name can be either from inptu_names or from the model inputs + if input_name in output_names: + # User specified an output name as a dynamic axis, so we skip it + continue + # input_name can be either from input_names or from the model inputs if input_name not in input_names_to_model_inputs: raise ValueError( - f"dynamix axis: {input_name} is not found in the input names: {input_names}" + f"dynamic axis: {input_name} is not found in the input names: {input_names}" ) model_input_name = input_names_to_model_inputs[input_name] if isinstance(axes, dict): @@ -108,13 +111,6 @@ def _get_torch_export_args( return args, kwargs -def _convert_version(path: str | os.PathLike, opset_version: int) -> None: - """Convert the ONNX file to a specific version.""" - model = onnx.load(path, load_external_data=False) - model = onnx.version_converter.convert_version(model, opset_version) - onnx.save(model, path) - - def export_compat( model: torch.nn.Module | torch.export.ExportedProgram @@ -142,20 +138,25 @@ def export_compat( artifacts_dir: str | os.PathLike = ".", fallback: bool = False, **_, -) -> _onnx_program.ONNXProgram | None: +) -> _onnx_program.ONNXProgram: + if opset_version is None: + # TODO(justinchuby): Change the hardcoded opset version for it to be flexible + opset_version = 18 + if isinstance(model, torch.export.ExportedProgram): - # We the model is already exported program, so the args, kwargs, and dynamic_shapes + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes # are not used dynamic_shapes = dynamic_shapes or {} else: args, kwargs = _get_torch_export_args(args, kwargs) if dynamic_shapes is None and dynamic_axes is not None: dynamic_shapes = _from_dynamic_axes_to_dynamic_shapes( - model, dynamic_axes, input_names + model, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=set(output_names or ()), ) - should_convert_version = False - try: onnx_program = _core.export( model, @@ -173,20 +174,6 @@ def export_compat( verbose=verbose, ) - if f is not None: - # Always save the initializers as external data to reduce the size of the ONNX file - onnx_program.save( - f, - include_initializers=export_params, - keep_initializers_as_inputs=keep_initializers_as_inputs, - external_data=external_data, - ) - if ( - opset_version is not None - and opset_version != onnx_program.model.opset_imports.get("") - ): - should_convert_version = True - except Exception as e: if fallback: if verbose is not False: @@ -194,6 +181,8 @@ def export_compat( "[torch.onnx] Falling back to legacy torch.onnx.export due " f"to the following error: {e}", ) + if f is None: + raise TypeError("f must be provided when fallback is enabled") from e torch.onnx.utils.export( model, # type: ignore[arg-type] args, @@ -206,20 +195,22 @@ def export_compat( dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, ) - onnx_program = None - if opset_version is None: - opset_version = 18 - if opset_version != 17: - should_convert_version = True + onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) else: raise - if f is not None and should_convert_version: - assert opset_version is not None - if verbose is not False: - print( - f"[torch.onnx] Converting the ONNX file to opset version {opset_version}..." - ) - _convert_version(f, opset_version) + # Converter opset version and optimize + onnx_program.model = onnxscript_apis.convert_version( + onnx_program.model, opset_version + ) + onnx_program.model = onnxscript_apis.optimize(onnx_program.model) + + if f is not None: + onnx_program.save( + f, + include_initializers=export_params, + keep_initializers_as_inputs=keep_initializers_as_inputs, + external_data=external_data, + ) return onnx_program diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index e775571d583327..09fae0ad2b88ee 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -14,40 +14,35 @@ import typing from typing import Any, Callable, Literal, Sequence -import onnx - import onnxscript import onnxscript.evaluator -import onnxscript.function_libs -import onnxscript.function_libs.torch_lib -import onnxscript.function_libs.torch_lib.registration from onnxscript import ir from onnxscript.ir import convenience as ir_convenience import torch import torch.fx from torch.export import graph_signature +from torch.onnx._internal._lazy_import import onnxscript_apis from torch.onnx._internal.exporter import ( _analysis, _building, _capture_strategies, _dispatching, + _errors, _fx_passes, _ir_passes, - _isolated, _onnx_program, _registration, _reporting, _tensors, _verification, - errors, ) if typing.TYPE_CHECKING: import os - import numpy as np + import numpy.typing as npt # Define utilities to convert PyTorch data types so users do not need to specify manually @@ -68,6 +63,9 @@ torch.int64: ir.DataType.INT64, torch.int8: ir.DataType.INT8, torch.uint8: ir.DataType.UINT8, + torch.uint16: ir.DataType.UINT16, + torch.uint32: ir.DataType.UINT32, + torch.uint64: ir.DataType.UINT64, } _BLUE = "\033[96m" _END = "\033[0m" @@ -102,10 +100,10 @@ def __init__(self, tensor: torch.Tensor, name: str | None = None): tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype), name=name ) - def __array__(self, dtype: Any = None) -> np.ndarray: - # numpy() calls __array__ in ir.Tensor + def numpy(self) -> npt.NDArray: + self.raw: torch.Tensor if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).__array__(dtype) + return self.raw.view(torch.uint16).numpy(force=True) if self.dtype in { ir.DataType.FLOAT8E4M3FN, ir.DataType.FLOAT8E4M3FNUZ, @@ -113,13 +111,27 @@ def __array__(self, dtype: Any = None) -> np.ndarray: ir.DataType.FLOAT8E5M2FNUZ, }: # TODO: Use ml_dtypes - return self.raw.view(torch.uint8).__array__(dtype) - return self.raw.__array__(dtype) + return self.raw.view(torch.uint8).numpy(force=True) + return self.raw.numpy(force=True) + + def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: + del copy # Unused, but needed for the signature + if dtype is None: + return self.numpy() + return self.numpy().__array__(dtype) def tobytes(self) -> bytes: # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because # it avoids copying to a NumPy array + import torch._subclasses.fake_tensor + + if isinstance(self.raw, torch._subclasses.fake_tensor.FakeTensor): + raise TypeError( + f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " + "with a tensor backed by real data using ONNXProgram.apply_weights() " + "or save the model without initializers by setting include_initializers=False." + ) tensor = self.raw.detach().cpu().contiguous() return bytes( (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( @@ -165,7 +177,11 @@ def _set_shape_types( def _set_shape_type( value: ir.Value, - meta_val: torch.Tensor | tuple[torch.Tensor], + meta_val: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | tuple[torch.Tensor], complex_to_float: bool, ) -> None: # TODO: Consider using meta["tensor_meta"] for this? Would it be faster? @@ -425,7 +441,7 @@ def _handle_call_function_node_with_lowering( if onnx_function is None: # TODO(justinchuby): Fall back to ATen op or do something else? - raise errors.DispatchError( + raise _errors.DispatchError( f"No ONNX function found for {node.target!r}. Failure message: {message}" ) @@ -452,7 +468,7 @@ def _handle_call_function_node_with_lowering( try: outputs = onnx_function(*onnx_args, **onnx_kwargs) except Exception as e: - raise errors.GraphConstructionError( + raise _errors.GraphConstructionError( f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'" ) from e @@ -546,7 +562,7 @@ def _add_nodes( # No lowering _handle_call_function_node(model.graph, node, node_name_to_values) except Exception as e: - raise errors.OnnxConversionError( + raise _errors.ConversionError( f"Error when translating node {node.format_node()}. See the stack trace for more information." ) from e return node_name_to_values @@ -691,13 +707,7 @@ def exported_program_to_ir( registry: The registry of all ONNX Script decomposition. """ if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops # noqa: F401 - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() if lower != "none": exported_program = _prepare_exported_program_for_export( exported_program, registry=registry @@ -768,7 +778,7 @@ def _exported_program_to_onnx_program( }, ), ir_version=9, - producer_name="torch", + producer_name="pytorch", producer_version=torch.__version__, ) @@ -963,7 +973,7 @@ def export( Raises: TorchExportError: If the export process fails with torch.export. - OnnxConversionError: If the ExportedProgram to ONNX translation fails. + ConversionError: If the ExportedProgram to ONNX translation fails. """ # Set up the error reporting facilities timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f") @@ -981,6 +991,8 @@ def export( program: torch.export.ExportedProgram | None = None # Step 1: Export the model with torch.export.export if the model is not already an ExportedProgram if isinstance(model, torch.export.ExportedProgram): + # We know the model is already exported program, so the args, kwargs, and dynamic_shapes + # are not used. program = model export_status.torch_export = True else: @@ -1042,7 +1054,7 @@ def export( # focus on the torch.export.export error. Errors from other strategies like # torch.jit.trace is due to the fallback and can be confusing to users. # We save all errors in the error report. - raise errors.TorchExportError( + raise _errors.TorchExportError( _STEP_ONE_ERROR_MESSAGE + ( f"\nError report has been saved to '{report_path}'." @@ -1071,13 +1083,7 @@ def export( try: # Build the ONNX function registry if registry is None: - # Trigger op registration - from onnxscript.function_libs.torch_lib import ops - - del ops - registry = _registration.ONNXRegistry.from_torchlib( - onnxscript.function_libs.torch_lib.registration.default_registry # type: ignore[arg-type] - ) + registry = _registration.ONNXRegistry.from_torchlib() # Process the exported program to run decompositions and type promotions etc. decomposed_program = _prepare_exported_program_for_export( @@ -1108,7 +1114,7 @@ def export( else: report_path = None - raise errors.OnnxConversionError( + raise _errors.ConversionError( _STEP_TWO_ERROR_MESSAGE + (f"\nError report has been saved to '{report_path}'." if report else "") + _summarize_exception_stack(e) @@ -1172,7 +1178,7 @@ def export( else: report_path = None - raise errors.OnnxConversionError( + raise _errors.ConversionError( _STEP_TWO_ERROR_MESSAGE + (f"\nError report has been saved to '{report_path}'." if report else "") + _summarize_exception_stack(e) @@ -1197,11 +1203,12 @@ def export( if not failed_results else _format_exceptions_for_all_strategies(failed_results), onnx_program.exported_program, - profile_result=profile_result, - export_status=export_status, decomp_comparison=_reporting.format_decomp_comparison( pre_decomp_unique_ops, post_decomp_unique_ops ), + export_status=export_status, + profile_result=profile_result, + model=onnx_program.model, registry=registry, ) verbose_print(f"Export report has been saved to '{report_path}'.") @@ -1214,28 +1221,13 @@ def export( # Step 3: (verify=True) Check the ONNX model with ONNX checker try: - verbose_print("Run `onnx.checker` on the ONNX model...") - - # TODO: Handle when model is >2GB - - model_proto = onnx_program.model_proto - byte_size = model_proto.ByteSize() - if byte_size < 2 * 1024 * 1024 * 1024: - # The checker may segfault so we need to run it in a separate process - _isolated.safe_call( - onnx.checker.check_model, # type:ignore[attr-defined] - onnx_program.model_proto, - full_check=True, - ) - export_status.onnx_checker = True - verbose_print("Run `onnx.checker` on the ONNX model... ✅") - else: - verbose_print( - f"Run `onnx.checker` on the ONNX model... ⚠️ Skipped because model is too large ({byte_size})." - ) + verbose_print("Check the ONNX model...") + onnxscript_apis.check_model(onnx_program.model) + export_status.onnx_checker = True + verbose_print("Check the ONNX model... ✅") except Exception as e: export_status.onnx_checker = False - verbose_print("Run `onnx.checker` on the ONNX model... ❌") + verbose_print("Check the ONNX model... ❌") if report: try: assert pre_decomp_unique_ops is not None diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py index 3bbff757e92ef2..de0dd0c0bcb73b 100644 --- a/torch/onnx/_internal/exporter/_decomp.py +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -1,8 +1,7 @@ -"""Build decomp table from PyTorch.""" - # mypy: allow-untyped-defs from __future__ import annotations +import itertools from typing import Callable, TYPE_CHECKING import torch @@ -40,33 +39,6 @@ def get_onnx_implemented_overloads( return registered_ops -def get_preserve_ops() -> set[torch._ops.OpOverload]: - """Return a set of CompositeImplicitAutograd ops that should be preserved.""" - aten = torch.ops.aten - # NOTE: Keep this list sorted - # NOTE: Do _not_ retain aten.linear as its decomposition is addmm, which is Gemm and is preferable for accuracy - return { - aten._upsample_bilinear2d_aa.default, - aten._upsample_nearest_exact1d.vec, - aten._upsample_nearest_exact2d.vec, - aten._upsample_nearest_exact3d.vec, - aten.group_norm.default, - aten.instance_norm.default, - aten.upsample_bilinear2d.default, - aten.upsample_bilinear2d.vec, - aten.upsample_linear1d.default, - aten.upsample_linear1d.vec, - aten.upsample_nearest1d.default, - aten.upsample_nearest1d.vec, - aten.upsample_nearest2d.default, - aten.upsample_nearest2d.vec, - aten.upsample_nearest3d.default, - aten.upsample_nearest3d.vec, - aten.upsample_trilinear3d.default, - aten.upsample_trilinear3d.vec, - } - - def create_onnx_friendly_decomposition_table( onnx_registered_ops: set[torch._ops.OperatorBase], ) -> dict[torch._ops.OperatorBase, Callable]: @@ -85,16 +57,19 @@ def create_onnx_friendly_decomposition_table( """ decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} - # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single - # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your - # definitions in a single TORCH_LIBRARY block. - for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): # type: ignore[attr-defined] + for op_overload, decomp_fn in itertools.chain( + torch._decomp._decomp_table_to_post_autograd_aten().items(), # type: ignore[attr-defined] + torch._decomp.decomposition_table.items(), # type: ignore[attr-defined] + ): # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX # symbolic function. # NOTE: Do not skip torch._refs decomps. They are fine because otherwise the model is # not exportable anyways. if op_overload in onnx_registered_ops: continue + # If it is HOP, we filter those out as well. + if not hasattr(op_overload, "_schema"): + continue decomposition_table[op_overload] = decomp_fn return decomposition_table diff --git a/torch/onnx/_internal/exporter/_dispatching.py b/torch/onnx/_internal/exporter/_dispatching.py index b8aecfaa93793a..11ed1af17aaade 100644 --- a/torch/onnx/_internal/exporter/_dispatching.py +++ b/torch/onnx/_internal/exporter/_dispatching.py @@ -2,9 +2,8 @@ from __future__ import annotations import logging -from typing import Sequence +from typing import Callable, Sequence -import onnxscript from onnxscript import ir import torch @@ -163,10 +162,22 @@ def _param_type_compatible_with_arg( def _get_type_from_tensor( - tensor: torch.Tensor | Sequence[torch.Tensor], + tensor: torch.Tensor + | torch.SymBool + | torch.SymInt + | torch.SymFloat + | Sequence[torch.Tensor], ) -> ir.TypeProtocol: if isinstance(tensor, torch.Tensor): return ir.TensorType(_torch_dtype_to_onnx_compatible_dtype(tensor.dtype)) + if isinstance(tensor, torch.SymBool): + return ir.TensorType(ir.DataType.BOOL) + if isinstance(tensor, torch.SymInt): + return ir.TensorType(ir.DataType.INT64) + if isinstance(tensor, torch.SymFloat): + return ir.TensorType(ir.DataType.FLOAT) + + # Handle sequences first_tensor = next((item for item in tensor if item is not None), None) if first_tensor is None: return ir.SequenceType(ir.TensorType(ir.DataType.UNDEFINED)) @@ -189,7 +200,7 @@ def _get_first_tensor_in_node_list( def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argument]: - # FIXME: node.target may not have a schema + assert hasattr(node.target, "_schema") torch_schema: torch.FunctionSchema = node.target._schema # type: ignore[union-attr] node_args = {} for arg, schema_arg in zip(node.args, torch_schema.arguments): @@ -201,8 +212,8 @@ def _get_named_fx_node_args(node: torch.fx.Node) -> dict[str, torch.fx.node.Argu def get_matching_overload( node: torch.fx.Node, - overloads: Sequence[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction], -) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: + overloads: Sequence[Callable], +) -> tuple[Callable | None, str]: """Get the overload that matches the node's arguments. Args: @@ -212,8 +223,14 @@ def get_matching_overload( Returns: A tuple containing the matched overload and a string describing the reason for failure or success. """ + if not hasattr(node.target, "_schema"): + # FIXME(justinchuby): When the target is a builtin, we should instead + # Match only the inputs positionally. Figure out how to do that as right + # now we assume all inputs are named. + return overloads[ + 0 + ], "The node target does not have a schema. Return the first one." named_args = _get_named_fx_node_args(node) - # FIXME: node.target may and builtin and not have a schema # FIXME: Handle when we don't know the names of the arguments schema_args: dict[str, torch.Argument] = { arg.name: arg @@ -308,7 +325,7 @@ def _arg_has_complex_dtype(arg) -> bool: def dispatch( node: torch.fx.Node, registry: _registration.ONNXRegistry -) -> tuple[onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction | None, str]: +) -> tuple[Callable | None, str]: """Dispatch a node to an ONNX function based on the node's target and the ONNX registry. Args: diff --git a/torch/onnx/_internal/exporter/_errors.py b/torch/onnx/_internal/exporter/_errors.py new file mode 100644 index 00000000000000..ff41bbe695fe7d --- /dev/null +++ b/torch/onnx/_internal/exporter/_errors.py @@ -0,0 +1,21 @@ +"""Error classes for the ONNX exporter.""" + +from __future__ import annotations + +import torch.onnx.errors + + +class TorchExportError(torch.onnx.errors.OnnxExporterError): + """Error during graph capturing using torch.export.""" + + +class ConversionError(torch.onnx.errors.OnnxExporterError): + """Error during ExportedProgram to ONNX conversion.""" + + +class DispatchError(ConversionError): + """Error during ONNX Function dispatching.""" + + +class GraphConstructionError(ConversionError): + """Error during ONNX graph construction.""" diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index bcab0d595b433e..c1083b65f8fcea 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -17,11 +17,7 @@ def decompose_with_registry( """ onnx_registered_ops = set(_decomp.get_onnx_implemented_overloads(registry)) decomp_table = _decomp.create_onnx_friendly_decomposition_table(onnx_registered_ops) - # Try to preserve some known CompositeImplicitAutograd ops - to_preserve = _decomp.get_preserve_ops() - # We can only preserve implemented ops - can_preserve = tuple(to_preserve.intersection(onnx_registered_ops)) - return exported_program.run_decompositions(decomp_table, _preserve_ops=can_preserve) + return exported_program.run_decompositions(decomp_table) def insert_type_promotion_nodes( @@ -39,7 +35,10 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.Graph """Remove all assertion and check nodes from the FX graph""" aten_assertion_targets = { torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.default, torch.ops.aten._assert_async.msg, + torch.ops.aten._assert_scalar.default, + torch.ops.aten._assert_tensor_metadata.default, } for node in graph_module.graph.nodes: if node.op == "call_function" and node.target in aten_assertion_targets: diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 646b2e7d494f1d..6fb62a5961f875 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -5,26 +5,28 @@ __all__ = ["ONNXProgram"] +import copy import gc import logging import os -import pathlib import tempfile import textwrap -from typing import Callable, IO, Sequence, TYPE_CHECKING +import warnings +from typing import Callable, Sequence, TYPE_CHECKING import torch -from torch.onnx._internal import _lazy_import -from torch.utils import _pytree as pytree +from torch.onnx._internal._lazy_import import onnx, onnxscript_apis, onnxscript_ir as ir +from torch.utils import _pytree -onnx = _lazy_import.onnx -ir = _lazy_import.onnxscript_ir - +# NOTE: DO NOT import module from torch.onnx._internal to this module in the global scope +# because ONNXProgram is exposed to the public API if TYPE_CHECKING: import onnxruntime as ort +_LARGE_MODEL_THRESHOLD = 1536 * 1024 * 1024 # 1536MB + logger = logging.getLogger(__name__) @@ -47,6 +49,15 @@ def _ort_session_initializer(model: str | bytes) -> ort.InferenceSession: ) +def _count_initializer_size(graph: ir.Graph) -> int: + """Count the total size of the initializers in bytes.""" + return sum( + v.const_value.nbytes + for v in graph.initializers.values() + if v.const_value is not None + ) + + class ONNXProgram: """A class to represent an ONNX program that is callable with torch tensors.""" @@ -105,7 +116,7 @@ def model_proto(self) -> onnx.ModelProto: def save( self, - destination: str | os.PathLike | IO[bytes], + destination: str | os.PathLike, *, include_initializers: bool = True, keep_initializers_as_inputs: bool = False, @@ -117,6 +128,22 @@ def save( When `external_data` is `True` or the model is larger than 2GB, the weights are saved as external data in a separate file. + Initializer (model weights) serialization behaviors: + - include_initializers=True, keep_initializers_as_inputs=False (default): + The initializers are included in the saved model. + - include_initializers=True, keep_initializers_as_inputs=True: + The initializers are included in the saved model and kept as model inputs. + Choose this option if you want the ability to override the model weights + during inference. + - include_initializers=False, keep_initializers_as_inputs=False: + The initializers are not included in the saved model and are not listed + as model inputs. Choose this option if you want to attach the initializers + to the ONNX model in a separate, post-processing, step. + - include_initializers=False, keep_initializers_as_inputs=True: + The initializers are not included in the saved model but are listed as model + inputs. Choose this option if you want to supply the initializers during + inference and want to minimize the size of the saved model. + Args: destination: The path to save the ONNX model to. include_initializers: Whether to include the initializers in the saved model. @@ -128,44 +155,49 @@ def save( Raises: TypeError: If `external_data` is `True` and `destination` is not a file path. """ + original_initializers = copy.copy(self.model.graph.initializers) + original_inputs = copy.copy(self.model.graph.inputs) + + # Adjust the model based on options if not include_initializers: self.model.graph.initializers.clear() - logger.warning( - "The initializers have been removed from the model. This is destructive. " - "Developers: Please implement ir.Model copy() and remove initializers on the copied model." - ) if keep_initializers_as_inputs: - self.model.graph.inputs.extend(self.model.graph.initializers.values()) # type: ignore[arg-type] - logger.warning( - "The initializers have been added as inputs to the model. This is destructive. " - "Developers: Please implement ir.Model copy() and remove initializers on the copied model." - ) - proto = ir.serde.serialize_model(self.model) - byte_size = proto.ByteSize() - model_too_large = (byte_size) >= 1 << 31 - if external_data or model_too_large: - # TODO: Create an IR pass to handle external tensors conversion - if model_too_large: - logger.warning( - "The serialized ONNX model is larger than 2GB (%s). " - "Saving the weights as external data in a separate file.", - byte_size, + self.model.graph.inputs.extend(original_initializers.values()) # type: ignore[arg-type] + + # Save the model to disk + if ( + external_data + or _count_initializer_size(self.model.graph) > _LARGE_MODEL_THRESHOLD + ): + onnxscript_apis.save_model_with_external_data(self.model, destination) + else: + ir.save(self.model, destination) + + # Revert the changes to the model + if not include_initializers: + self.model.graph.initializers.update(original_initializers) + if keep_initializers_as_inputs: + self.model.graph.inputs.clear() + self.model.graph.inputs.extend(original_inputs) + + def apply_weights(self, state_dict: dict[str, torch.Tensor]) -> None: + """Apply the weights from the specified state dict to the ONNX model. + Args: + state_dict: The state dict containing the weights to apply to the ONNX model. + """ + from torch.onnx._internal.exporter import _core + + for name, tensor in state_dict.items(): + if name in self.model.graph.initializers: + self.model.graph.initializers[name].const_value = _core.TorchTensor( + tensor, name ) - if not isinstance(destination, (str, os.PathLike)): - raise TypeError( - "Saving the weights as external data is only supported when destination is a file path" + else: + warnings.warn( + f"Weight '{name}' not found in the model. Skipped applying.", + category=torch.onnx.errors.OnnxExporterWarning, + stacklevel=1, ) - destination_path = pathlib.Path(destination) - # Create the directory if it does not exist - data_path = f"{destination_path.name}.data" - onnx.save_model( - proto, - destination, - save_as_external_data=True, - location=data_path, - ) - else: - onnx.save_model(proto, destination) def initialize_inference_session( self, @@ -182,27 +214,17 @@ def initialize_inference_session( """ # TODO(justinchuby): Allow different inference options logger.debug("Initializing the inference session.") - proto = ir.serde.serialize_model(self.model) - byte_size = proto.ByteSize() - model_too_large = (byte_size) >= 1 << 31 - - if model_too_large: - logger.debug( - "The serialized ONNX model is larger than 2GB (%s).", byte_size - ) + if ( + byte_size := _count_initializer_size(self.model.graph) + ) > _LARGE_MODEL_THRESHOLD: + logger.debug("The model initializers is larger than 1.5GB (%s).", byte_size) # Save the model to a temporary file if too large self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) model_path = os.path.join(self._tempdir.name, "model.onnx") - data_path = "model.onnx.data" - onnx.save_model( - proto, - model_path, - save_as_external_data=True, - location=data_path, - ) + self.save(model_path, external_data=True) model = model_path else: - model = proto.SerializeToString() # type: ignore[assignment] + model = self.model_proto.SerializeToString() # type: ignore[assignment] self._inference_session = initializer(model) logger.debug("Inference session initialized.") @@ -231,7 +253,7 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]: def _flatten_inputs(model_args, model_kwargs): - flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs)) + flattened_args, _ = _pytree.tree_flatten((model_args, model_kwargs)) return flattened_args diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index b649188c264eac..0afa084b06c80d 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -18,17 +18,17 @@ import operator import types import typing -from typing import Callable, Literal, Mapping, Union +from typing import Callable, Literal, Union from typing_extensions import TypeAlias import torch import torch._ops +from torch.onnx._internal._lazy_import import onnxscript_apis from torch.onnx._internal.exporter import _schemas if typing.TYPE_CHECKING: import onnxscript - from onnxscript.function_libs.torch_lib import registration as torchlib_registration _DEFAULT_OPSET_VERSION = 18 @@ -49,7 +49,7 @@ class OnnxDecompMeta: device: The device the function is registered to. If None, it is registered to all devices. """ - onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction + onnx_function: Callable fx_target: TorchOp is_custom: bool = False is_complex: bool = False @@ -135,24 +135,21 @@ def opset_version(self) -> int: return self._opset_version @classmethod - def from_torchlib( - cls, - torchlib_registry: Mapping[str, torchlib_registration.OverloadedFunction] - | None = None, - ) -> ONNXRegistry: + def from_torchlib(cls) -> ONNXRegistry: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ registry = cls() - if torchlib_registry is None: - from onnxscript.function_libs.torch_lib import ( - registration as torchlib_registration, - ) - torchlib_registry = torchlib_registration.default_registry # type: ignore[assignment] - for qualified_name, aten_overloads_func in torchlib_registry.items(): # type: ignore[union-attr] + torchlib_ops = onnxscript_apis.get_torchlib_ops() + + for meta in torchlib_ops: + qualified_name = meta.qualified_name + overload_func = meta.function + domain = meta.domain + name = meta.name try: # NOTE: This is heavily guarded with try-except because we don't want # to fail the entire registry population if one function fails. @@ -162,33 +159,19 @@ def from_torchlib( target = _get_overload(qualified_name) if target is None: continue - for overload_func in aten_overloads_func.overloads: - overload_func.signature = _schemas.OpSignature.from_function( - overload_func, - overload_func.function_ir.domain, - overload_func.name, - ) - onnx_decomposition = OnnxDecompMeta( - onnx_function=overload_func, - fx_target=target, - is_custom=False, - is_complex=False, - ) - registry._register(target, onnx_decomposition) - - for complex_func in aten_overloads_func.complex: - overload_func.signature = _schemas.OpSignature.from_function( - overload_func, - overload_func.function_ir.domain, - overload_func.name, - ) - onnx_decomposition = OnnxDecompMeta( - onnx_function=complex_func, - fx_target=target, - is_custom=False, - is_complex=True, - ) - registry._register(target, onnx_decomposition) + + overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] + overload_func, + domain, + name, + ) + onnx_decomposition = OnnxDecompMeta( + onnx_function=overload_func, + fx_target=target, + is_custom=False, + is_complex=meta.is_complex, + ) + registry._register(target, onnx_decomposition) except Exception: logger.exception("Failed to register '%s'. Skipped", qualified_name) continue diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py index 12e2f3d2f44df2..fbc406053fbee8 100644 --- a/torch/onnx/_internal/exporter/_schemas.py +++ b/torch/onnx/_internal/exporter/_schemas.py @@ -121,6 +121,8 @@ def has_default(self) -> bool: @dataclasses.dataclass(frozen=True) class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + name: str type: ir.AttributeType required: bool @@ -435,7 +437,7 @@ def from_function( # https://github.com/python/cpython/issues/102405 type_hints = typing.get_type_hints(func) - params = [] + params: list[Parameter | AttributeParameter] = [] # Create a mapping from type to a unique name type_constraints: dict[str, TypeConstraintParam] = {} @@ -446,8 +448,19 @@ def from_function( param.name, py_signature, ) - type_constraints[param.name] = TypeConstraintParam.any_value( - f"T_{param.name}" + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) ) else: type_ = type_hints[param.name] @@ -490,7 +503,7 @@ def from_function( type_constraints[type_constraint_name] = type_constraint # 4. Create Parameter params.append( - Parameter( # type: ignore[arg-type] + Parameter( name=param.name, type_constraint=type_constraint, required=param.default is inspect.Parameter.empty, diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py index cfe8f7dc2a661a..2fdafacbe06f4b 100644 --- a/torch/onnx/_internal/exporter/_tensors.py +++ b/torch/onnx/_internal/exporter/_tensors.py @@ -88,9 +88,6 @@ def __lt__(self, other): def __le__(self, other): return self._opset.LessOrEqual(self, other) - def __eq__(self, other): - return self._opset.Equal(self, other) - def __ge__(self, other): return self._opset.GreaterOrEqual(self, other) diff --git a/torch/onnx/_internal/exporter/_verification.py b/torch/onnx/_internal/exporter/_verification.py index 10560414616fbf..c4eec16da4990f 100644 --- a/torch/onnx/_internal/exporter/_verification.py +++ b/torch/onnx/_internal/exporter/_verification.py @@ -91,7 +91,10 @@ def verify_onnx_program( ) abs_diff = abs_diff.flatten() rel_diff = rel_diff.flatten() - bins = torch.tensor([0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10]) + bins = torch.tensor( + [0.0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10, 1000000], + dtype=abs_diff.dtype, + ) abs_diff_hist = torch.histogram(abs_diff, bins=bins) rel_diff_hist = torch.histogram(rel_diff, bins=bins) results.append( diff --git a/torch/onnx/_internal/exporter/errors.py b/torch/onnx/_internal/exporter/errors.py deleted file mode 100644 index a70eccf3a5633e..00000000000000 --- a/torch/onnx/_internal/exporter/errors.py +++ /dev/null @@ -1,30 +0,0 @@ -class ExporterError(RuntimeError): - """Error during export.""" - - -class TorchExportError(ExporterError): - """Error during torch.export.export.""" - - -class OnnxConversionError(ExporterError): - """Error during ONNX conversion.""" - - -class DispatchError(OnnxConversionError): - """Error during ONNX Funtion dispatching.""" - - -class GraphConstructionError(OnnxConversionError): - """Error during graph construction.""" - - -class OnnxCheckerError(ExporterError): - """Error during ONNX model checking.""" - - -class OnnxRuntimeError(ExporterError): - """Error during ONNX Runtime execution.""" - - -class OnnxValidationError(ExporterError): - """Output value mismatch.""" diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index 388ae29cb699ed..5246788756f3ed 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -176,7 +176,7 @@ class Transform(abc.ABC): One important aspect to note is that if the transformation modifies the model input and/or output signature, (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep` - are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`. + are needed to reconcile :attr:`ONNXProgram.model_proto`. That is, the model signature and the model representation must match. As an additional feature, this class provides builtin support for transformation recording using the diagnostics. diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index 01b851a49421d3..b81c7254751ba8 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -5,7 +5,6 @@ import logging import operator import re -import types from typing import Callable, Sequence import onnxscript # type: ignore[import] @@ -20,7 +19,6 @@ _pass, diagnostics, onnxfunction_dispatcher, - op_validation, type_utils as fx_type_utils, ) from torch.utils import _pytree @@ -396,7 +394,6 @@ def run_node( node, fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, fx_name_to_onnxscript_value: dict[ @@ -411,7 +408,6 @@ def run_node( node: The FX node to be translated. fx_graph_module: The FX graph module containing the node. onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. - op_level_debug (bool): Whether to enable op level debug. onnxscript_graph: The ONNX graph to be populated. onnxscript_tracer: The tracer to trace the ONNX graph. fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. @@ -446,7 +442,6 @@ def run_node( onnxscript_tracer, fx_name_to_onnxscript_value, onnxfunction_dispatcher, - op_level_debug, fx_graph_module, ) elif node.op == "call_method": @@ -459,7 +454,6 @@ def run_node( onnxscript_tracer, fx_graph_module, onnxfunction_dispatcher, - op_level_debug, ) elif node.op == "output": self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) @@ -474,7 +468,6 @@ def run( self, fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph | None = None, ) -> onnxscript_graph_building.TorchScriptGraph: @@ -483,7 +476,6 @@ def run( Args: fx_graph_module: FX graph module to be translated. onnxfunction_dispatcher: ONNX function dispatcher. - op_level_debug: Whether to enable op-level debug. parent_onnxscript_graph: The parent TorchScript graph. Must be provided if `fx_graph_module` is a submodule. If not provided, `fx_graph_module` is assumed to be the root module. @@ -541,13 +533,11 @@ def run( # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible # with FakeTensorMode. with torch.utils._mode_utils.no_dispatch(): - # node_fixed_shape is only used on op_level_debug purpose. for node in fx_graph_module.graph.nodes: self.run_node( node, fx_graph_module, onnxfunction_dispatcher, - op_level_debug, onnxscript_graph, onnxscript_tracer, fx_name_to_onnxscript_value, @@ -621,7 +611,6 @@ def call_function( | tuple[onnxscript_graph_building.TorchScriptTensor, ...], ], onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, fx_graph_module: torch.fx.GraphModule, ): # aten ops and other stateless functions. @@ -675,23 +664,6 @@ def call_function( assert isinstance( output, (onnxscript_graph_building.TorchScriptTensor, tuple) ), type(output) - # NOTE(titaiwang): We bypass two kinds of ops as it's not meaningful to - # validate them with op level debug. - # 1. aten::sym_size: The op is simply get item from a list of tensors. - # 2. BuiltinFunction: It doesn't supported tensor - if ( - op_level_debug - and node.target != torch.ops.aten.sym_size - and not isinstance(node.target, types.BuiltinFunctionType) - ): - op_validation.validate_op_between_ort_torch( - self.diagnostic_context, - node, - symbolic_fn, - fx_args, - fx_kwargs, - fx_graph_module, - ) fx_name_to_onnxscript_value[node.name] = output def output( @@ -734,7 +706,6 @@ def call_module( tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, root_fx_graph_module: torch.fx.GraphModule, onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - op_level_debug: bool, ) -> None: """Export a fx.GraphModule submodule to ONNXScript graph. @@ -753,7 +724,6 @@ def call_module( tracer: The tracer used to trace the ONNXScript graph. root_fx_graph_module: The root FX module. onnxfunction_dispatcher: The dispatcher. - op_level_debug: Whether to enable op-level debug. """ assert isinstance( node.target, str @@ -766,7 +736,7 @@ def call_module( ), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." sub_onnxscript_graph = self.run( - sub_module, onnxfunction_dispatcher, op_level_debug, parent_onnxscript_graph + sub_module, onnxfunction_dispatcher, parent_onnxscript_graph ) onnx_args, _ = _wrap_fx_args_as_onnxscript_args( diff --git a/torch/onnx/_internal/fx/op_validation.py b/torch/onnx/_internal/fx/op_validation.py deleted file mode 100644 index 834e1157b20d4d..00000000000000 --- a/torch/onnx/_internal/fx/op_validation.py +++ /dev/null @@ -1,381 +0,0 @@ -# mypy: allow-untyped-defs -"""Module for handling op-level validation during exporting.""" - -from __future__ import annotations - -import logging -from typing import Any, Callable, Sequence - -import onnxscript # type: ignore[import] -from onnxscript import evaluator # type: ignore[import] - -import torch -import torch.fx -from torch.fx.experimental import symbolic_shapes -from torch.onnx import _constants, _type_utils as jit_type_utils -from torch.onnx._internal.fx import ( - diagnostics, - fx_onnx_interpreter, - type_utils as fx_type_utils, -) -from torch.utils import _pytree - - -def _op_level_debug_message_formatter( - fn: Callable, - self, - node: torch.fx.Node, - symbolic_fn: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - *args, - **kwargs, -) -> str: - return ( - f"FX Node: {node.op}::{node.target}[name={node.name}]. \n" - f"ONNX Node: {symbolic_fn.name}[opset={symbolic_fn.opset}]." - ) - - -@diagnostics.diagnose_call( - diagnostics.rules.op_level_debugging, - diagnostic_message_formatter=_op_level_debug_message_formatter, -) -def validate_op_between_ort_torch( - diagnostic_context: diagnostics.DiagnosticContext, - node: torch.fx.Node, - symbolic_fn: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - fx_args: list[fx_type_utils.Argument], - fx_kwargs: dict[str, fx_type_utils.Argument], - fx_graph_module: torch.fx.GraphModule, -): - """Validate the op between ONNX Runtime and PyTorch. - - The function will run the op in ONNX Runtime and PyTorch and compare the - results. It doesn't break the exporting process, but saves each op validated - result into SARIF, under the section of `fx_onnx_interpreter`. - - There are three signs can be found: - 1. Blue: Pass - 2. Yellow: Bypass - - Args: - node (torch.fx.Node): The validated fx.node - symbolic_fn (Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]): The corresponded ONNX node - torch_args (list): torch argument inputs - torch_kwargs (dict): torch keyword argument inputs - fx_graph_module (torch.fx.GraphModule): The fx.GraphModule that contains the nodes - """ - # op-level validation - # Symbolic_fn should have the same output as node.target (torch ops) - - try: - torch_args, torch_kwargs = _wrap_fx_args_as_torch_args( - fx_args, fx_kwargs, fx_graph_module - ) - except ValueError as value_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails due to unsupported input types" - ): - diagnostic.log_source_exception(logging.WARNING, value_error) - diagnostic.level = diagnostics.levels.WARNING - return - - with evaluator.default_as(evaluator.ort_evaluator): - try: - expected_outputs = node.target(*torch_args, **torch_kwargs) # type: ignore[operator] - # NOTE: randomly generating indices/dim: INT64 could go out of bounds - except IndexError as index_error: - # TODO(titaiwang): How to bound indices/dim: INT64 - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Op level debug is bypassed"): - diagnostic.log_source_exception(logging.WARNING, index_error) - diagnostic.level = diagnostics.levels.WARNING - return - # NOTE: Error in torch ops with random inputs generated from FakTensors - except RuntimeError as runtime_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails on PyTorch" - ): - diagnostic.log_source_exception(logging.WARNING, runtime_error) - diagnostic.level = diagnostics.levels.WARNING - return - - try: - ( - function_eager_inputs, - function_eager_attributes, - ) = _convert_torch_args_to_onnxfunction_args( - symbolic_fn.param_schemas(), - torch_args, - torch_kwargs, - allow_extra_kwargs=True, - ) - # NOTE: Apply kwargs preprocessing AFTER they are split - function_eager_attributes = ( - fx_onnx_interpreter.filter_incompatible_and_dtype_convert_kwargs( - function_eager_attributes - ) - ) - # NOTE: Incompatible kwargs or missing required args - except TypeError as type_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Op level debug is bypassed"): - diagnostic.log_source_exception(logging.WARNING, type_error) - diagnostic.level = diagnostics.levels.WARNING - return - try: - ort_outputs = symbolic_fn( - *function_eager_inputs, **function_eager_attributes - ) - # NOTE: Error in ONNX Runtime with random inputs generated from FakTensors - except RuntimeError as runtime_error: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section( - logging.WARNING, "Op level debug fails on ONNXRUNTIME" - ): - diagnostic.log_source_exception(logging.WARNING, runtime_error) - diagnostic.level = diagnostics.levels.WARNING - return - - flattened_torch_outputs, _ = _pytree.tree_flatten(expected_outputs) - flattened_function_outputs, _ = _pytree.tree_flatten(ort_outputs) - - assert flattened_torch_outputs - assert len(flattened_torch_outputs) == len(flattened_function_outputs) - - for torch_output, function_output in zip( - flattened_torch_outputs, flattened_function_outputs - ): - if isinstance( - torch_output, torch.Tensor - ) and fx_type_utils.is_torch_complex_dtype(torch_output.dtype): - torch_output = torch.view_as_real(torch_output.resolve_conj()) - try: - if isinstance(function_output, onnxscript.tensor.Tensor): - function_output = function_output.value - - # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match - torch.testing.assert_close( - torch.tensor(function_output).cpu(), - torch_output.cpu() - if isinstance(torch_output, torch.Tensor) - else torch.tensor(torch_output).cpu(), - rtol=1e-4, - atol=1e-3, - ) - except AssertionError as e: - diagnostic = diagnostic_context.inflight_diagnostic() - with diagnostic.log_section(logging.WARNING, "Validation failed"): - diagnostic.log_source_exception(logging.WARNING, e) - diagnostic.level = diagnostics.levels.WARNING - - -def _convert_symint_to_int_in_shape(shape: torch.Size) -> torch.Size: - """Convert SymInt to int in shape - - Args: - shape (torch.Size): The shape of a tensor - Raises: - ValueError: When SymInt is found in shape - Returns: - torch.Size: The shape of a tensor with SymInt converted to int - - """ - list_int_shape = [] - for dim in shape: - if isinstance(dim, torch.SymInt): - if symbolic_shapes.has_hint(dim): - list_int_shape.append(symbolic_shapes.hint_int(dim)) - else: - raise ValueError( - f"An unbacked SymInt found in shape. SymInt: {dim}; " - f"torch.Size: {shape}. There is no hint for SymInt." - ) - else: - list_int_shape.append(dim) - return torch.Size(list_int_shape) - - -def generate_random_tensors(shape: torch.Size, dtype: torch.dtype): - shape = _convert_symint_to_int_in_shape(shape) - - if dtype == torch.uint8: - return torch.randint( - low=_constants.UINT8_MIN, high=_constants.UINT8_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int8: - return torch.randint( - low=_constants.INT8_MIN, high=_constants.INT8_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int16: - return torch.randint( - low=_constants.INT16_MIN, high=_constants.INT16_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int32: - return torch.randint( - low=_constants.INT32_MIN, high=_constants.INT32_MAX, size=shape, dtype=dtype - ) - if dtype == torch.int64: - return torch.randint( - low=_constants.INT64_MIN, high=_constants.INT64_MAX, size=shape, dtype=dtype - ) - if dtype == torch.bool: - random_numbers = torch.rand(shape) - return torch.where( - random_numbers > 0.5, torch.tensor(True), torch.tensor(False) - ) - if fx_type_utils.is_torch_complex_dtype(dtype): - # ONNX does not support complex values, but supports their real representation - return torch.view_as_complex( - torch.randn((*shape, 2), dtype=fx_type_utils.from_complex_to_float(dtype)) - ) - return torch.randn(shape, dtype=dtype) - - -def _fx_args_to_torch_args( - fx_args: list[fx_type_utils.Argument], fx_graph_module: torch.fx.GraphModule -) -> list[fx_type_utils.Argument]: - """Recursively convert fx args to torch args""" - wrapped_args: list[fx_type_utils.Argument] = [] - for arg in fx_args: - if isinstance(arg, torch.fx.Node): - fake_tensor = arg.meta.get("val") - if fake_tensor is None and arg.op == "get_attr": - fake_tensor = getattr(fx_graph_module, arg.target) # type: ignore[operator, arg-type] - # NOTE: Currently, we are aware of - # FakeTensor/Tensor/SymInt/SymFloat/Symbool/int/float/bool could be in - # arg.meta["val"]/get_attr. - if isinstance(fake_tensor, torch.Tensor): - real_tensor = generate_random_tensors( - fake_tensor.shape, fake_tensor.dtype - ) - wrapped_args.append(real_tensor) - elif isinstance(fake_tensor, (int, float, bool)): - wrapped_args.append(fake_tensor) - elif symbolic_shapes.has_hint(fake_tensor): # type: ignore[arg-type] - wrapped_args.append(symbolic_shapes.hint_int(fake_tensor)) # type: ignore[arg-type] - else: - raise ValueError( - f"Unexpected input argument type found inside fx.Node. arg: {arg}; " - f"arg.meta['val']/get_attr: {fake_tensor}; type(arg.meta['val']/get_attr): " - f"{type(fake_tensor)}." - ) - elif isinstance(arg, Sequence): - wrapped_args.append(_fx_args_to_torch_args(arg, fx_graph_module)) # type: ignore[arg-type] - elif isinstance(arg, (int, float, torch.dtype)) or arg is None: - wrapped_args.append(arg) - elif isinstance(arg, torch.device): - wrapped_args.append(str(arg)) - else: - raise ValueError( - f"Unexpected input argument type is found in node arguments. arg: {arg}; " - ) - - return wrapped_args - - -def _wrap_fx_args_as_torch_args( - fx_args: list[fx_type_utils.Argument], - fx_kwargs: dict[str, fx_type_utils.Argument], - fx_graph_module: torch.fx.GraphModule, -) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: - """Prepare torch format args and kwargs for op-level validation by using fake tensor to create real tensor to feed in ops""" - - # NOTE: This function only supports FakeTensor with concrete shapes - torch_args: list[fx_type_utils.Argument] = _fx_args_to_torch_args( - fx_args, fx_graph_module - ) - return torch_args, fx_kwargs - - -# NOTE: Referenced from onnxscript internal function: _tag_arguments_with_param_schemas. -def _convert_torch_args_to_onnxfunction_args( - param_schemas: Sequence[onnxscript.values.ParamSchema], - args: list[fx_type_utils.Argument], - kwargs: dict[str, fx_type_utils.Argument], - allow_extra_kwargs: bool = False, -) -> tuple[list[Any], dict[str, Any]]: - """Convert Python args and kwargs to OnnxFunction acceptable with matching ONNX ParamSchema. - - NOTE: This is different from the param_schema separating in dispatcher, since at this point - we are already sure that the args and kwargs are in order and matched. - - Args: - param_schemas: The parameter schemas of an Op or a OnnxFunction. - args: The Python positional arguments supplied by the caller. - kwargs: The Python keyword arguments supplied by the caller. - allow_extra_kwargs: Whether to allow extra keyword arguments. - When set to True, extra/unknown arguments will be ignored. - - Returns: - A tuple of two elements: - - A list of Python positional argument. - - An ordered dictionary of Python keyword argument names and its values. - - Raises: - TypeError: When allow_extra_kwargs is False and there are unknown kwargs. - TypeError: When a required input is not provided. - """ - # args, kwargs and param_schemas should be all in order - # user may not specify all inputs or attributes - - all_param_names = {param.name for param in param_schemas} - extra_kwargs = set(kwargs).difference(all_param_names) - if extra_kwargs and not allow_extra_kwargs: - raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'") - - tagged_args: list[Any] = [] - tagged_kwargs: dict[str, Any] = {} - - for i, param in enumerate(param_schemas): - if param.is_variadic_input: - # Exhaust all remaining args - tagged_args.extend(arg for arg in args[i:]) - args = [] - continue - if i < len(args): - if param.is_input or isinstance(args[i], torch.dtype): - tagged_args.append(_convert_tensor_to_numpy(args[i])) - else: - tagged_args.append(args[i]) - elif param.name in kwargs: - if param.is_input: - tagged_kwargs[param.name] = _convert_tensor_to_numpy(kwargs[param.name]) - else: - tagged_kwargs[param.name] = kwargs[param.name] - elif param.required: - raise TypeError(f"Required input/attribute '{param}' was not provided") - - return tagged_args, tagged_kwargs - - -def _convert_tensor_to_numpy(input: fx_type_utils.Argument) -> Any: - try: - import numpy as np - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - f"{__name__} needs numpy, but it's not installed." - ) from exc - - if isinstance(input, torch.Tensor): - if torch.is_complex(input): - # from complex to real representation - input = torch.view_as_real(input.resolve_conj()) - return input.detach().cpu().numpy() - if isinstance(input, torch.dtype): - return int(jit_type_utils.JitScalarType.from_dtype(input).onnx_type()) # type: ignore[union-attr,call-overload] - if isinstance(input, (tuple, list)): - if len(input) == 0: - return np.array((), dtype=np.int64) - if isinstance(input[0], torch.Tensor): - return [_convert_tensor_to_numpy(x) for x in input] - if isinstance(input[0], bool): - return np.array(input, dtype=np.bool_) - - # Just a sequence of numbers - if isinstance(input[0], int): - return np.array(input, dtype=np.int64) - if isinstance(input[0], float): - return np.array(input) - return input diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 6397beb5f089a4..81cb6ccb7439d9 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -554,6 +554,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), + ElementwiseTypePromotionRule( + "aten", "dot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -870,10 +873,7 @@ def preview_type_promotion( "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( - "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), - ElementwiseTypePromotionRule( - "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + "aten", "normal", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), ElementwiseTypePromotionRule( "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT @@ -924,9 +924,6 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ), - ElementwiseTypePromotionRule( - "aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ), ElementwiseTypePromotionRule( "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), @@ -1030,6 +1027,9 @@ def preview_type_promotion( ElementwiseTypePromotionRule( "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ), + ElementwiseTypePromotionRule( + "aten", "vdot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), ElementwiseTypePromotionRule( "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH ), diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py deleted file mode 100644 index aff9c154cd9e4f..00000000000000 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ /dev/null @@ -1,128 +0,0 @@ -# mypy: allow-untyped-defs -# NOTE: This file is referenced by name at -# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. -# introduced by https://github.com/pytorch/pytorch/pull/98894. -# If this file is renamed, moved, etc please update the reference there! - -from __future__ import annotations - -from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING - -import torch._dynamo -import torch.fx -from torch.onnx._internal import _exporter_legacy, io_adapter -from torch.onnx._internal.diagnostics import infra - - -if TYPE_CHECKING: - import torch.onnx - from torch.export.exported_program import ExportedProgram - - -class TorchExport(_exporter_legacy.FXGraphExtractor): - """Generates a FX GraphModule using torch.export API - Args: - aten_graph: If True, exports a graph with ATen operators. - If False, exports a graph with Python operators. - """ - - def __init__( - self, - aten_graph: bool | None = None, - ): - super().__init__() - self.aten_graph = aten_graph or True - - def generate_fx( - self, - options: _exporter_legacy.ResolvedExportOptions, - model: ExportedProgram, # type: ignore[override] - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - # No need to translate callable to FX graph. - # This FX Graph extractor assumes `model` was obtained through - # exported_program = torch.export.export( - # model, - # args=model_args, # type: ignore[arg-type] - # kwargs=model_kwargs, # type: ignore[arg-type] - # ) - - # Export FX graph to ONNX ModelProto. - self.input_adapter.append_step( - io_adapter.FlattenInputWithTreeSpecValidationInputStep() - ) - self.input_adapter.append_step( - io_adapter.PrependParamsBuffersConstantAotAutogradInputStep() - ) - - # ONNX does not support None inputs. During graph building, all None inputs - # are removed. Here we register this step to input adapter. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) - - # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 - # Dynamo doesn't support non-tensor inputs. - options.fx_tracer.input_adapter.append_step( - io_adapter.RemoveNonTensorInputStep() - ) - - # ONNX does not support complex inputs. During graph building, all complex inputs - # are converted to real representation inputs. Here we register this step to - # input/output adapter. - options.fx_tracer.input_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationInputStep() - ) - - updated_model_args = self.input_adapter.apply( - *model_args, model=model, **model_kwargs - ) - - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - - # Output post-processing steps should happen after `FlattenOutputStep`. - options.fx_tracer.output_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationOutputStep() - ) - - options.fx_tracer.output_adapter.append_step( - io_adapter.PrependParamsAndBuffersAotAutogradOutputStep() - ) - - # run_decomposition generates a new graph module with decomposed ops. - # Thus, we need to run this step after io_adapters. - model = model.run_decompositions(options.decomposition_table) - - # Export FX graph to ONNX ModelProto. - return self.pre_export_passes( # type: ignore[return-value] - options, model, model.graph_module, updated_model_args - ) - - def pre_export_passes( - self, - options: _exporter_legacy.ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - # TODO: Import here to prevent circular dependency - from torch.onnx._internal.fx import analysis, passes - - diagnostic_context = options.diagnostic_context - - # ONNX does not support concept of (implicit) type promotion. - # Insert type casts explicitly where needed. - fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run() - - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, fx_module, options.onnxfunction_dispatcher - ).analyze(infra.levels.ERROR) - - # This operation should be invoked as the last pre export pass. - # See [NOTE: Modularize pass ordering] - fx_module = passes.Modularize( - diagnostic_context, fx_module, is_exported_program=True - ).run() - - return fx_module diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py index 97626dcc70bb58..6a8fd4d4fefaf3 100644 --- a/torch/onnx/_internal/onnxruntime.py +++ b/torch/onnx/_internal/onnxruntime.py @@ -34,32 +34,18 @@ if TYPE_CHECKING: import onnx - -try: - # Use try-except to initialize package-dependent global variables. - import onnxruntime # type: ignore[import] - from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import] - - # This is not use directly in DORT but needed by underlying exporter, - # so we still need to check if it exists. - importlib.import_module("onnxscript") + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC import torch.onnx import torch.onnx._internal import torch.onnx._internal._exporter_legacy import torch.onnx._internal.diagnostics import torch.onnx._internal.fx.decomposition_table - import torch.onnx._internal.fx.passes - from torch.onnx._internal.fx import fx_onnx_interpreter - from torch.onnx._internal.fx.type_utils import ( - _TORCH_DTYPE_TO_NUMPY_DTYPE, - _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE, - from_python_type_to_onnx_tensor_element_type, - ) + import torch.onnx._internal.fx.passes # noqa: TCH004 - _SUPPORT_ONNXRT = True -except ImportError: - _SUPPORT_ONNXRT = False + +_SUPPORT_ONNXRT: Optional[bool] = None __all__ = [ "is_onnxrt_backend_supported", @@ -87,6 +73,35 @@ def is_onnxrt_backend_supported() -> bool: ... print("pip install onnx onnxscript onnxruntime") ... """ + global _SUPPORT_ONNXRT + + if _SUPPORT_ONNXRT is None: + # `onnxruntime` might import a lot of other runtime packages, + # e.g. apex, deepspeed, transformers. + # So lazy-importing onnxruntime to avoid possible circular import. + try: + importlib.import_module("onnxruntime") + importlib.import_module("onnxruntime.capi._pybind_state") + + # This is not use directly in DORT but needed by underlying exporter, + # so we still need to check if it exists. + importlib.import_module("onnxscript") + + import torch.onnx # noqa: F401 + import torch.onnx._internal # noqa: F401 + import torch.onnx._internal._exporter_legacy # noqa: F401 + import torch.onnx._internal.diagnostics # noqa: F401 + from torch.onnx._internal.fx import ( # noqa: F401 + decomposition_table, + fx_onnx_interpreter, + passes, + type_utils, + ) + + _SUPPORT_ONNXRT = True + except ImportError: + _SUPPORT_ONNXRT = False + return _SUPPORT_ONNXRT @@ -143,6 +158,8 @@ def _nvtx_range_pop(): def _get_ort_device_type(device_type: str): + from onnxruntime.capi import _pybind_state as ORTC + if device_type == "cuda": return ORTC.OrtDevice.cuda() if device_type == "cpu": @@ -305,6 +322,8 @@ def _get_onnx_devices( ..., ], ) -> Tuple["ORTC.OrtDevice", ...]: + from onnxruntime.capi import _pybind_state as ORTC + def _device_id_or_zero(device_id: int) -> int: return device_id or 0 @@ -338,6 +357,10 @@ def _map_tensor_or_sym_to_device( def _get_ortvalues_from_torch_tensors( tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...] ) -> Tuple[torch.Tensor, ...]: + from onnxruntime.capi import _pybind_state as ORTC + + from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE + ortvalues = ORTC.OrtValueVector() ortvalues.reserve(len(tensors)) dtypes = [] @@ -436,6 +459,9 @@ def _run_onnx_session_with_ortvaluevector( ..., ], ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + from onnxruntime.capi import _pybind_state as ORTC + _nvtx_range_push("contiguous") inputs = tuple( _adjust_scalar_from_fx_to_onnx(arg, value_info) @@ -514,6 +540,8 @@ def _run_onnx_session_with_fetch( ..., ], ) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: + import onnxruntime + inputs = tuple( _adjust_scalar_from_fx_to_onnx(arg, value_info) for arg, value_info in zip(inputs, input_value_infos) @@ -570,6 +598,11 @@ def __init__( ) def is_supported(self, *args): + from torch.onnx._internal.fx.type_utils import ( + _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE, + from_python_type_to_onnx_tensor_element_type, + ) + # Compare the args and the input schema in ONNX model and # return the first match. if len(args) != len(self.input_value_infos): @@ -728,6 +761,12 @@ class OrtBackend: """ def __init__(self, options: Optional[OrtBackendOptions] = None): + from onnxruntime.capi import _pybind_state as ORTC + + import torch.onnx + import torch.onnx._internal._exporter_legacy + import torch.onnx._internal.fx.decomposition_table + self._options: Final = OrtBackendOptions() if options is None else options # options.export_options contains information shared between exporter and DORT. @@ -849,6 +888,10 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar it means we delegate the computation to _ort_acclerated_call and therefore onnxruntime.InferenceSession. """ + import onnxruntime + + from torch.onnx._internal.fx import fx_onnx_interpreter, passes + cached_execution_info_per_session = ( self._all_ort_execution_info.search_reusable_session_execution_info( graph_module, *args @@ -867,7 +910,7 @@ def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwar # It's first time seeing such as graph. Let's make a new session # (type: onnxruntime.InferenceSession) for it. - graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront( + graph_module = passes.MovePlaceholderToFront( self._resolved_onnx_exporter_options.diagnostic_context, graph_module, ).run() @@ -915,7 +958,7 @@ def maybe_map_to_meta_val(value): # Cast FX variables if they will result schema-mismatch when searching # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, # but ONNX expects add(double_tensor, double_tensor). - graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( + graph_module = passes.InsertTypePromotion( self._resolved_onnx_exporter_options.diagnostic_context, graph_module ).run() # Start the per-node exporting process. It's conceptually a for loop @@ -923,7 +966,6 @@ def maybe_map_to_meta_val(value): exported = fx_interpreter.run( fx_graph_module=graph_module, onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, - op_level_debug=self._resolved_onnx_exporter_options.op_level_debug, ) # Convert the exported result to ONNX ModelProto. onnx_model = exported.to_model_proto( @@ -1168,8 +1210,6 @@ def reusable(a: OrtBackendOptions, b: OrtBackendOptions): if a.export_options is not None and b.export_options is not None: return ( a.export_options.dynamic_shapes == b.export_options.dynamic_shapes - and a.export_options.op_level_debug - == b.export_options.op_level_debug and a.export_options.diagnostic_options == b.export_options.diagnostic_options and a.export_options.onnx_registry is b.export_options.onnx_registry diff --git a/torch/onnx/errors.py b/torch/onnx/errors.py index 2f5271596b5c28..f6e035a8a85f1d 100644 --- a/torch/onnx/errors.py +++ b/torch/onnx/errors.py @@ -2,41 +2,38 @@ from __future__ import annotations -import textwrap -from typing import TYPE_CHECKING - -from torch.onnx import _constants -from torch.onnx._internal import diagnostics - - -if TYPE_CHECKING: - from torch import _C __all__ = [ - "OnnxExporterError", "OnnxExporterWarning", - "CheckerError", "SymbolicValueError", "UnsupportedOperatorError", ] +import textwrap +from typing import TYPE_CHECKING -class OnnxExporterWarning(UserWarning): - """Base class for all warnings in the ONNX exporter.""" +if TYPE_CHECKING: + from torch import _C -class OnnxExporterError(RuntimeError): - """Errors raised by the ONNX exporter.""" + +class OnnxExporterWarning(UserWarning): + """Warnings in the ONNX exporter.""" -class CheckerError(OnnxExporterError): - """Raised when ONNX checker detects an invalid model.""" +class OnnxExporterError(RuntimeError): + """Errors raised by the ONNX exporter. This is the base class for all exporter errors.""" class UnsupportedOperatorError(OnnxExporterError): """Raised when an operator is unsupported by the exporter.""" + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed def __init__(self, name: str, version: int, supported_version: int | None): + from torch.onnx import _constants + from torch.onnx._internal import diagnostics + if supported_version is not None: diagnostic_rule: diagnostics.infra.Rule = ( diagnostics.rules.operator_supported_in_newer_opset_version @@ -60,6 +57,8 @@ def __init__(self, name: str, version: int, supported_version: int | None): class SymbolicValueError(OnnxExporterError): """Errors around TorchScript values and nodes.""" + # NOTE: This is legacy and is only used by the torchscript exporter + # Clean up when the torchscript exporter is removed def __init__(self, msg: str, value: _C.Value): message = ( f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 1b10cc28531a0d..ae33ddf58c6e09 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -150,7 +150,6 @@ def scaled_dot_product_attention( ), "is_causal and attn_mask cannot be set at the same time" assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" - scale = symbolic_helper._maybe_get_const(scale, "f") if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 7370411ad3d221..37d7ee4b35a739 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -139,14 +139,14 @@ def disable_apex_o2_state_dict_hook(model: torch.nn.Module | torch.jit.ScriptFun @contextlib.contextmanager def setup_onnx_logging(verbose: bool): - is_originally_enabled = torch.onnx.is_onnx_log_enabled() - if is_originally_enabled or verbose: - torch.onnx.enable_log() + is_originally_enabled = _C._jit_is_onnx_log_enabled + if is_originally_enabled or verbose: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(True) try: yield finally: - if not is_originally_enabled: - torch.onnx.disable_log() + if not is_originally_enabled: # type: ignore[truthy-function] + _C._jit_set_onnx_log_enabled(False) @contextlib.contextmanager @@ -175,7 +175,7 @@ def _get_torch_export_args( def export( model: torch.nn.Module | torch.jit.ScriptModule | torch.jit.ScriptFunction, args: tuple[Any, ...] | torch.Tensor, - f: str | None = None, + f: str, *, kwargs: dict[str, Any] | None = None, export_params: bool = True, @@ -1125,7 +1125,7 @@ def _model_to_graph( module=module, ) except Exception as e: - torch.onnx.log("Torch IR graph at exception: ", graph) + _C._jit_onnx_log("Torch IR graph at exception: ", graph) raise is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)) @@ -1452,7 +1452,7 @@ def _get_module_attributes(module): try: attrs[k] = getattr(module, k) except AttributeError: - torch.onnx.log(f"Skipping module attribute '{k}'") + _C._jit_onnx_log(f"Skipping module attribute '{k}'") continue return attrs @@ -1642,20 +1642,8 @@ def _export( custom_opsets, ) if verbose: - torch.onnx.log("Exported graph: ", graph) + _C._jit_onnx_log("Exported graph: ", graph) onnx_proto_utils._export_file(proto, f, export_type, export_map) - # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX, - # we can skip this check. - # If large model format export is enabled, proto will only contain data location instead of - # raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto - # string in memory. - if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and ( - not val_use_external_data_format - ): - try: - _C._check_onnx_proto(proto) - except RuntimeError as e: - raise errors.CheckerError(e) from e finally: assert GLOBALS.in_onnx_export GLOBALS.in_onnx_export = False diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index a21f1ffbba7783..f489252f5a7b27 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -21,6 +21,7 @@ from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt import torch import torch._C._onnx as _C_onnx @@ -98,7 +99,7 @@ def _flatten_tuples(elem): # TODO(justinchuby): Add type checking by narrowing down the return type when input is None -def _to_numpy(elem) -> list | np.ndarray: +def _to_numpy(elem) -> list | npt.NDArray: if isinstance(elem, torch.Tensor): if elem.requires_grad: return elem.detach().cpu().numpy() diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index ee53d1c5fd765c..7354092dda4e02 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -6,22 +6,22 @@ future. """ -from torch.optim import lr_scheduler, swa_utils -from torch.optim._adafactor import Adafactor -from torch.optim.adadelta import Adadelta -from torch.optim.adagrad import Adagrad -from torch.optim.adam import Adam -from torch.optim.adamax import Adamax -from torch.optim.adamw import AdamW -from torch.optim.asgd import ASGD -from torch.optim.lbfgs import LBFGS -from torch.optim.nadam import NAdam -from torch.optim.optimizer import Optimizer -from torch.optim.radam import RAdam -from torch.optim.rmsprop import RMSprop -from torch.optim.rprop import Rprop -from torch.optim.sgd import SGD -from torch.optim.sparse_adam import SparseAdam +from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils +from torch.optim._adafactor import Adafactor as Adafactor +from torch.optim.adadelta import Adadelta as Adadelta +from torch.optim.adagrad import Adagrad as Adagrad +from torch.optim.adam import Adam as Adam +from torch.optim.adamax import Adamax as Adamax +from torch.optim.adamw import AdamW as AdamW +from torch.optim.asgd import ASGD as ASGD +from torch.optim.lbfgs import LBFGS as LBFGS +from torch.optim.nadam import NAdam as NAdam +from torch.optim.optimizer import Optimizer as Optimizer +from torch.optim.radam import RAdam as RAdam +from torch.optim.rmsprop import RMSprop as RMSprop +from torch.optim.rprop import Rprop as Rprop +from torch.optim.sgd import SGD as SGD +from torch.optim.sparse_adam import SparseAdam as SparseAdam Adafactor.__module__ = "torch.optim" diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 57dcbd85a83164..f0a10efefd12f5 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -92,7 +92,10 @@ class LRScheduler: _get_lr_called_within_step: bool = False def __init__( - self, optimizer: Optimizer, last_epoch=-1, verbose="deprecated" + self, + optimizer: Optimizer, + last_epoch: int = -1, + verbose="deprecated", ): # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): @@ -319,7 +322,7 @@ def __init__( self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer @@ -419,7 +422,7 @@ def __init__( self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.optimizer = optimizer @@ -523,8 +526,8 @@ def __init__( self, optimizer: Optimizer, step_size: int, - gamma=0.1, - last_epoch=-1, + gamma: float = 0.1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.step_size = step_size @@ -582,8 +585,8 @@ def __init__( self, optimizer: Optimizer, milestones: Iterable[int], - gamma=0.1, - last_epoch=-1, + gamma: float = 0.1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.milestones = Counter(milestones) @@ -648,9 +651,9 @@ class ConstantLR(LRScheduler): def __init__( self, optimizer: Optimizer, - factor=1.0 / 3, - total_iters=5, - last_epoch=-1, + factor: float = 1.0 / 3, + total_iters: int = 5, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if factor > 1.0 or factor < 0: @@ -726,10 +729,10 @@ class LinearLR(LRScheduler): def __init__( self, optimizer: Optimizer, - start_factor=1.0 / 3, - end_factor=1.0, - total_iters=5, - last_epoch=-1, + start_factor: float = 1.0 / 3, + end_factor: float = 1.0, + total_iters: int = 5, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if start_factor > 1.0 or start_factor <= 0: @@ -803,7 +806,11 @@ class ExponentialLR(LRScheduler): """ def __init__( - self, optimizer: Optimizer, gamma: float, last_epoch=-1, verbose="deprecated" + self, + optimizer: Optimizer, + gamma: float, + last_epoch: int = -1, + verbose="deprecated", ): # noqa: D107 self.gamma = gamma super().__init__(optimizer, last_epoch, verbose) @@ -859,7 +866,7 @@ def __init__( optimizer: Optimizer, schedulers: List[LRScheduler], milestones: List[int], - last_epoch=-1, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if len(schedulers) < 1: @@ -910,7 +917,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" self.last_epoch += 1 idx = bisect_right(self._milestones, self.last_epoch) @@ -992,9 +999,9 @@ class PolynomialLR(LRScheduler): def __init__( self, optimizer: Optimizer, - total_iters=5, - power=1.0, - last_epoch=-1, + total_iters: int = 5, + power: float = 1.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.total_iters = total_iters @@ -1074,8 +1081,8 @@ def __init__( self, optimizer: Optimizer, T_max: int, - eta_min=0.0, - last_epoch=-1, + eta_min: float = 0.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 self.T_max = T_max @@ -1179,7 +1186,7 @@ def __init__( group["lr"] for group in self._schedulers[-1].optimizer.param_groups ] - def step(self): + def step(self): # type: ignore[override] """Perform a step.""" for scheduler in self._schedulers: scheduler.step() @@ -1288,13 +1295,13 @@ def __init__( self, optimizer: Optimizer, mode: Literal["min", "max"] = "min", - factor=0.1, - patience=10, - threshold=1e-4, + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, threshold_mode: Literal["rel", "abs"] = "rel", - cooldown=0, + cooldown: int = 0, min_lr: Union[List[float], float] = 0, - eps=1e-8, + eps: float = 1e-8, verbose="deprecated", ): # noqa: D107 if factor >= 1.0: @@ -1525,16 +1532,16 @@ def __init__( optimizer: Optimizer, base_lr: Union[float, List[float]], max_lr: Union[float, List[float]], - step_size_up=2000, + step_size_up: int = 2000, step_size_down: Optional[int] = None, mode: Literal["triangular", "triangular2", "exp_range"] = "triangular", - gamma=1.0, + gamma: float = 1.0, scale_fn: Optional[Callable[[float], float]] = None, scale_mode: Literal["cycle", "iterations"] = "cycle", - cycle_momentum=True, - base_momentum=0.8, - max_momentum=0.9, - last_epoch=-1, + cycle_momentum: bool = True, + base_momentum: float = 0.8, + max_momentum: float = 0.9, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 # Attach optimizer @@ -1740,9 +1747,9 @@ def __init__( self, optimizer: Optimizer, T_0: int, - T_mult=1, - eta_min=0.0, - last_epoch=-1, + T_mult: int = 1, + eta_min: float = 0.0, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): @@ -1959,15 +1966,15 @@ def __init__( total_steps: Optional[int] = None, epochs: Optional[int] = None, steps_per_epoch: Optional[int] = None, - pct_start=0.3, + pct_start: float = 0.3, anneal_strategy: Literal["cos", "linear"] = "cos", - cycle_momentum=True, + cycle_momentum: bool = True, base_momentum: Union[float, List[float]] = 0.85, max_momentum: Union[float, List[float]] = 0.95, - div_factor=25.0, - final_div_factor=1e4, - three_phase=False, - last_epoch=-1, + div_factor: float = 25.0, + final_div_factor: float = 1e4, + three_phase: bool = False, + last_epoch: int = -1, verbose="deprecated", ): # noqa: D107 # Validate optimizer diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index a417a44817ebfd..8f7993842c1009 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -245,17 +245,13 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: are supported. (default: None) .. note:: The foreach and fused implementations are typically faster than the for-loop, - single-tensor implementation. Thus, if the user has not specified BOTH flags - (i.e., when foreach = fused = None), we will attempt defaulting to the foreach - implementation when the tensors are all on CUDA. For example, if the user specifies - True for fused but nothing for foreach, we will run the fused implementation. If - the user specifies False for foreach but nothing for fused (or False for fused but - nothing for foreach), we will run the for-loop implementation. If the user specifies - True for both foreach and fused, we will prioritize fused over foreach, as it is - typically faster. We attempt to use the fastest, so the hierarchy goes fused -> - foreach -> for-loop. HOWEVER, since the fused implementation is relatively new, - we want to give it sufficient bake-in time, so we default to foreach and NOT - fused when the user has not specified either flag.""" + single-tensor implementation, with fused being theoretically fastest with both + vertical and horizontal fusion. As such, if the user has not specified either + flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach + implementation when the tensors are all on CUDA. Why not fused? Since the fused + implementation is relatively new, we want to give it sufficient bake-in time. + To specify fused, pass True for fused. To force running the for-loop + implementation, pass False for either foreach or fused. """ _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to capture in a CUDA graph. Passing True can impair ungraphed performance, diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 9b77ad7fe3eeaf..876f4e1d697bf2 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -34,8 +34,8 @@ def __init__( eps: float = 1e-8, weight_decay: float = 0, momentum: float = 0, - centered=False, - capturable=False, + centered: bool = False, + capturable: bool = False, foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index b270afce4d60a2..46af5ae77537ec 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -15,6 +15,7 @@ _use_grad_for_differentiable, DeviceDict, Optimizer, + ParamsT, ) @@ -24,12 +25,12 @@ class SGD(Optimizer): # noqa: D101 def __init__( self, - params, + params: ParamsT, lr: Union[float, Tensor] = 1e-3, momentum: float = 0, dampening: float = 0, weight_decay: float = 0, - nesterov=False, + nesterov: bool = False, *, maximize: bool = False, foreach: Optional[bool] = None, diff --git a/torch/overrides.py b/torch/overrides.py index 91b383f32483a6..7a568d7e22c1cd 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1252,8 +1252,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.vsplit: lambda input, indices_or_sections: -1, torch.vstack: lambda tensors, out=None: -1, torch.where: lambda condition, x=None, y=None: -1, - torch.wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, - torch.wrapped_quantized_linear_prepacked: ( + torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, + torch._wrapped_quantized_linear_prepacked: ( lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950 ), torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index 7845ffe39a2a2e..8856ad6c37ccf0 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -8,7 +8,6 @@ EXT2, EXT4, GLOBAL, - Pickler, PicklingError, STACK_GLOBAL, ) @@ -18,7 +17,18 @@ from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer -class PackagePickler(_Pickler): +class _PyTorchLegacyPickler(_Pickler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._persistent_id = None + + def persistent_id(self, obj): + if self._persistent_id is None: + return super().persistent_id(obj) + return self._persistent_id(obj) + + +class PackagePickler(_PyTorchLegacyPickler): """Package-aware pickler. This behaves the same as a normal pickler, except it uses an `Importer` @@ -113,6 +123,6 @@ def create_pickler(data_buf, importer, protocol=4): if importer is sys_importer: # if we are using the normal import library system, then # we can use the C implementation of pickle which is faster - return Pickler(data_buf, protocol=protocol) + return _PyTorchLegacyPickler(data_buf, protocol=protocol) else: return PackagePickler(importer, data_buf, protocol=protocol) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 939ae73a99afb4..4b0708c4a78f57 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -36,6 +36,28 @@ PROFILER_STEP_NAME = "ProfilerStep" +class _NumpyEncoder(json.JSONEncoder): + """ + Json encoder for numpy types (np.int, np.float, np.array etc.) + Returns default encoder if numpy is not available + """ + + def default(self, obj): + """Encode NumPy types to JSON""" + try: + import numpy as np + except ImportError: + return json.JSONEncoder.default(self, obj) + if isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + else: + return json.JSONEncoder.default(self, obj) + + def supported_activities(): """ Returns a set of supported profiler tracing activities. @@ -187,7 +209,9 @@ def start_trace(self): if kineto_available(): dist_info = self._get_distributed_info() if dist_info: - self.add_metadata_json("distributedInfo", json.dumps(dist_info)) + self.add_metadata_json( + "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder) + ) if hasattr(torch, "_inductor"): import torch._inductor.config as inductor_config @@ -931,5 +955,6 @@ def _record_pg_config(self) -> None: ): pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info torch.autograd._record_function_with_args_enter( - "## process_group:init ##", json.dumps(pg_config_info) + "## process_group:init ##", + json.dumps(pg_config_info, cls=_NumpyEncoder), ) diff --git a/torch/serialization.py b/torch/serialization.py index 36e5c035ff1516..d937680c031c78 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,6 +11,7 @@ import sys import tarfile import tempfile +import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -54,10 +55,13 @@ "LoadEndianness", "get_default_load_endianness", "set_default_load_endianness", + "get_default_mmap_options", + "set_default_mmap_options", "clear_safe_globals", "get_safe_globals", "add_safe_globals", "safe_globals", + "skip_data", ] @@ -85,6 +89,22 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] +# _serialization_tls is used to store thread local state specific to serialization +# that needs to be propagated to other files, in particular we use this for +# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) +# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) +class _SerializationLocal(threading.local): + def __init__(self): + super().__init__() + self.map_location: Optional[MAP_LOCATION] = None + self.skip_data: bool = False + self.materialize_fake_tensors: bool = False + + +_serialization_tls = _SerializationLocal() + + class SourceChangeWarning(Warning): pass @@ -163,9 +183,9 @@ def get_default_mmap_options() -> int: return _default_mmap_options -def set_default_mmap_options(flags: int): +class set_default_mmap_options: """ - Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. + Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported. Please open an issue if you need any other option to be added here. @@ -176,17 +196,27 @@ def set_default_mmap_options(flags: int): Args: flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` """ - global _default_mmap_options - if IS_WINDOWS: - raise RuntimeError( - "Changing the default mmap options is currently not supported for Windows" - ) - if flags != MAP_PRIVATE and flags != MAP_SHARED: - raise ValueError( - "Invalid argument in function set_default_mmap_options, " - f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" - ) - _default_mmap_options = flags + + def __init__(self, flags: int) -> None: + if IS_WINDOWS: + raise RuntimeError( + "Changing the default mmap options is currently not supported for Windows" + ) + if flags != MAP_PRIVATE and flags != MAP_SHARED: + raise ValueError( + "Invalid argument in function set_default_mmap_options, " + f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}" + ) + global _default_mmap_options + self.prev = _default_mmap_options + _default_mmap_options = flags + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + global _default_mmap_options + _default_mmap_options = self.prev def clear_safe_globals() -> None: @@ -256,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ +class skip_data: + """ + Context-manager that skips writing storage bytes for ``torch.save`` calls. + + Storages will still be saved, but the space that their bytes would usually be written to + will be empty space. The storage bytes can then be populated in a separate pass. + + .. warning:: + The ``skip_data`` context manager is an early prototype and is subject to change. + + Args: + materialize_fake_tensors: Whether to materialize FakeTensors. + + Example: + >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows") + >>> import tempfile + >>> t = torch.randn(2, 3) + >>> with tempfile.NamedTemporaryFile() as f: + ... with torch.serialization.skip_data(): + ... torch.save(t, f.name) + ... torch.load(f.name, weights_only=True) + tensor([[0., 0., 0.], + [0., 0., 0.]]) + """ + + def __init__(self, materialize_fake_tensors: bool = False): + self.materialize_fake_tensors = materialize_fake_tensors + + def __enter__(self): + global _serialization_tls + self._old_skip_data = _serialization_tls.skip_data + self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors + _serialization_tls.skip_data = True + _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors + + def __exit__(self, type, value, tb): + global _serialization_tls + _serialization_tls.skip_data = self._old_skip_data + _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors + + def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -785,6 +856,11 @@ def save( ) return else: + global _serialization_tls + if _serialization_tls.skip_data: + raise RuntimeError( + "Cannot use skip_data=True with _use_new_zipfile_serialization=False" + ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -929,8 +1005,12 @@ def persistent_id(obj: Any) -> Optional[Tuple]: pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol) pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol) pickle_module.dump(sys_info, f, protocol=pickle_protocol) - pickler = pickle_module.Pickler(f, protocol=pickle_protocol) - pickler.persistent_id = persistent_id + + class PyTorchLegacyPickler(pickle_module.Pickler): + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchLegacyPickler(f, protocol=pickle_protocol) pickler.dump(obj) serialized_storage_keys = sorted(serialized_storages.keys()) @@ -943,7 +1023,13 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): +def _save( + obj, + zip_file, + pickle_module, + pickle_protocol, + _disable_byteorder_record, +): serialized_storages = {} id_map: Dict[int, str] = {} @@ -978,7 +1064,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if storage.data_ptr() != 0: + if str(storage.device) != "meta" and storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -989,7 +1075,10 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - location = location_tag(storage) + if hasattr(obj, "_fake_device") and obj._fake_device is not None: + location = str(obj._fake_device) + else: + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -998,8 +1087,12 @@ def persistent_id(obj): # Write the pickle data for `obj` data_buf = io.BytesIO() - pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol) - pickler.persistent_id = persistent_id + + class PyTorchPickler(pickle_module.Pickler): # type: ignore[name-defined] + def persistent_id(self, obj): + return persistent_id(obj) + + pickler = PyTorchPickler(data_buf, protocol=pickle_protocol) pickler.dump(obj) data_value = data_buf.getvalue() zip_file.write_record("data.pkl", data_value, len(data_value)) @@ -1015,14 +1108,18 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - zip_file.write_record(name, storage, num_bytes) + global _serialization_tls + if _serialization_tls.skip_data: + zip_file.write_record_metadata(name, num_bytes) + else: + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file + zip_file.write_record(name, storage, num_bytes) def load( @@ -1172,6 +1269,14 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE + global _serialization_tls + skip_data = _serialization_tls.skip_data + if skip_data: + raise RuntimeError( + "`torch.load` called within a torch.serialization.skip_data context manager " + "is not supported yet. Please call torch.load outside the skip_data context manager." + ) + if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1746,9 +1851,10 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - torch._utils._thread_local_state.map_location = map_location + global _serialization_tls + _serialization_tls.map_location = map_location result = unpickler.load() - del torch._utils._thread_local_state.map_location + _serialization_tls.map_location = None torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 51784b111941ee..6626b6d1f3aaf2 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -750,7 +750,7 @@ def general_hamming(M, .. math:: w_n = 1 - 0.36358 \cos{(z_n)} + 0.48917 \cos{(2z_n)} - 0.13659 \cos{(3z_n)} + 0.01064 \cos{(4z_n)} -where ``z_n = 2 \u03c0 n/ M``. +where :math:`z_n = \frac{2 \pi n}{M}`. """, """ diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 091e91d37f604a..f919718bb5dc66 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -89,14 +89,14 @@ def make_triton_contiguous(t): """Return input as a triton-contiguous tensor. A triton-contiguous tensor is defined as a tensor that has strides - with minimal value equal to 1. + with minimal value smaller than or equal to 1. While triton kernels support triton-non-contiguous tensors (all - strides being greater than 1 or having 0 strides) arguments, a - considerable slow-down occurs because tensor data is copied - element-wise rather than chunk-wise. + strides being greater than 1) arguments, a considerable slow-down + occurs because tensor data is copied element-wise rather than + chunk-wise. Zero strides is assumed to not have this defect. """ - if min(t.stride()) != 1: + if min(t.stride()) > 1: # TODO: investigate if contiguity along other axes than the # last one can be beneficial for performance return t.contiguous() @@ -1097,6 +1097,8 @@ def _int_bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, @@ -1120,6 +1122,8 @@ def _int_bsr_dense_addmm( dense, beta=beta, alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, out=out, skip_checks=skip_checks, max_grid=max_grid, @@ -1134,11 +1138,21 @@ def bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, skip_checks: bool = False, max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None, meta: Optional[dict] = None, ): + """Compute + + out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1) + + where left_alpha, right_alpha are (* + 1)-D tensors when + specified, otherwise, these are treated as tensors filled with + ones. + """ f_name = "bsr_dense_addmm" values = bsr.values() crow_indices = bsr.crow_indices() @@ -1150,8 +1164,8 @@ def bsr_dense_addmm( # todo: implement checks + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) if out is None: - original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: @@ -1163,6 +1177,30 @@ def bsr_dense_addmm( out.mul_(beta) return out + left_alpha_is_one = False + right_alpha_is_one = False + if left_alpha is None: + left_alpha_is_one = True + left_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand( + *original_batch_dims_broadcasted, M, N + ) + + if right_alpha is None: + right_alpha_is_one = True + right_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand( + *original_batch_dims_broadcasted, M, N + ) + assert left_alpha.stride()[-1] == 0 + assert right_alpha.stride()[-2] == 0 + if meta is None: sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) meta = bsr_dense_addmm_meta( @@ -1178,9 +1216,16 @@ def bsr_dense_addmm( ) out_backup = out - crow_indices, col_indices, values, input, dense, out = prepare_inputs( - bsr, input, dense, out - ) + ( + crow_indices, + col_indices, + values, + input, + dense, + left_alpha, + right_alpha, + out, + ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out) BM, BK = blocksize SPLIT_N = meta.get("SPLIT_N", N // BM) @@ -1191,6 +1236,9 @@ def bsr_dense_addmm( dense = tile_to_blocksize(dense, (BK, BN)) input = tile_to_blocksize(input, (BM, BN)) + left_alpha = tile_to_blocksize(left_alpha, (BM, BN)) + right_alpha = tile_to_blocksize(right_alpha, (BM, BN)) + dot_out_dtype = { torch.float16: tl.float32, torch.bfloat16: tl.float32, @@ -1216,6 +1264,8 @@ def bsr_dense_addmm( col_indices: (0, None, None), input: (0, -3, -4), dense: (0, -3, None), + left_alpha: (0, -3, -4), + right_alpha: (0, -3, -4), out: (0, -3, -4), } @@ -1229,6 +1279,8 @@ def kernel(grid, *sliced_tensors): beta_is_one=beta == 1, beta_is_nonzero=beta != 0, alpha_is_one=alpha == 1, + left_alpha_is_one=left_alpha_is_one, + right_alpha_is_one=right_alpha_is_one, BLOCKSIZE_ROW=BM, BLOCKSIZE_INNER=BK, BLOCKSIZE_COL=BN, @@ -2278,6 +2330,22 @@ def _bsr_strided_addmm_kernel( dense_row_block_stride, dense_col_block_stride, # dense epilogue + # left_alpha prologue + left_alpha_ptr, + left_alpha_batch_stride, + left_alpha_tiled_row_stride, + left_alpha_tiled_col_stride: tl.constexpr, + left_alpha_row_block_stride, + left_alpha_col_block_stride: tl.constexpr, + # left_alpha epilogue + # right_alpha prologue + right_alpha_ptr, + right_alpha_batch_stride, + right_alpha_tiled_row_stride: tl.constexpr, + right_alpha_tiled_col_stride, + right_alpha_row_block_stride: tl.constexpr, + right_alpha_col_block_stride, + # right_alpha epilogue # output prologue output_ptr, output_batch_stride, @@ -2291,6 +2359,8 @@ def _bsr_strided_addmm_kernel( beta_is_one: tl.constexpr, beta_is_nonzero: tl.constexpr, alpha_is_one: tl.constexpr, + left_alpha_is_one: tl.constexpr, + right_alpha_is_one: tl.constexpr, BLOCKSIZE_ROW: tl.constexpr, BLOCKSIZE_COL: tl.constexpr, BLOCKSIZE_INNER: tl.constexpr, @@ -2299,6 +2369,12 @@ def _bsr_strided_addmm_kernel( GROUP_SIZE_ROW: tl.constexpr, SPLIT_N: tl.constexpr, ): + # left/right_alpha tensors are originally (* + 1)-dimensional + assert left_alpha_tiled_col_stride == 0 + assert left_alpha_col_block_stride == 0 + assert right_alpha_tiled_row_stride == 0 + assert right_alpha_row_block_stride == 0 + batch_pid = tl.program_id(axis=2) row_block_pid = tl.program_id(axis=0) col_block_pid = tl.program_id(axis=1) @@ -2324,17 +2400,6 @@ def _bsr_strided_addmm_kernel( inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) col_block_arange = tl.arange(0, BLOCKSIZE_COL) - if beta_is_nonzero: - # Pointers are set to exact write-to locations - input_ptrs = ( - input_ptr - + input_batch_stride * batch_pid - + input_tiled_row_stride * row_block_pid - + input_tiled_col_stride * col_block_pid - + input_row_block_stride * row_block_arange[:, None] - + input_col_block_stride * col_block_arange[None, :] - ) - # Pointers are set to the first block of the current row. values_block_ptrs = ( values_ptr @@ -2371,14 +2436,7 @@ def _bsr_strided_addmm_kernel( + col_indices_stride * nnz_offset ) - # alpha is never 0 - if beta_is_nonzero: - output_acc_block = tl.load(input_ptrs).to(acc_dtype) # type: ignore[possibly-undefined] - if not (beta_is_one and alpha_is_one): - beta_alpha = beta / alpha - output_acc_block *= beta_alpha - else: - output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) + output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_COL), dtype=acc_dtype) for _ in range(row_nnz): values_block = tl.load(values_block_ptrs) @@ -2402,6 +2460,42 @@ def _bsr_strided_addmm_kernel( if not alpha_is_one: output_acc_block *= alpha + if not left_alpha_is_one: + left_alpha_ptrs = ( + left_alpha_ptr + + left_alpha_batch_stride * batch_pid + + left_alpha_tiled_row_stride * row_block_pid + + left_alpha_tiled_col_stride * col_block_pid + + left_alpha_row_block_stride * row_block_arange[:, None] + + left_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(left_alpha_ptrs) + + if not right_alpha_is_one: + right_alpha_ptrs = ( + right_alpha_ptr + + right_alpha_batch_stride * batch_pid + + right_alpha_tiled_row_stride * row_block_pid + + right_alpha_tiled_col_stride * col_block_pid + + right_alpha_row_block_stride * row_block_arange[:, None] + + right_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(right_alpha_ptrs) + + if beta_is_nonzero: + input_ptrs = ( + input_ptr + + input_batch_stride * batch_pid + + input_tiled_row_stride * row_block_pid + + input_tiled_col_stride * col_block_pid + + input_row_block_stride * row_block_arange[:, None] + + input_col_block_stride * col_block_arange[None, :] + ) + if beta_is_one: + output_acc_block += tl.load(input_ptrs) + else: + output_acc_block += beta * tl.load(input_ptrs) + # write back the result tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty)) diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index a97d9c502dc349..3d353c00cc21bd 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -507,9 +507,7 @@ def bench(meta, bsr=bsr, dense=dense): def test_func(): return bsr_scatter_mm(bsr, dense, indices_data=indices_data) - ms_min = triton.testing.do_bench( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) return ms_min @@ -601,6 +599,8 @@ def tune_bsr_dense_addmm( *, beta=1, alpha=1, + left_alpha=None, + right_alpha=None, out=None, store=False, verbose=False, @@ -660,10 +660,18 @@ def tune_bsr_dense_addmm( def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out): def test_func(): return bsr_dense_addmm( - input, bsr, dense, beta=beta, alpha=alpha, meta=meta, out=out + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + meta=meta, + out=out, ) - return triton.testing.do_bench(test_func, warmup=500, rep=100, fast_flush=False) + return triton.testing.do_bench(test_func, warmup=500, rep=100) # The step function that increments a specified meta parameter: def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK): @@ -725,6 +733,8 @@ def optimize_bsr_dense_addmm( bk, beta=1, alpha=1, + use_left_alpha=False, + use_right_alpha=False, dtype=torch.float16, device="cuda", sparsity=0.5, @@ -738,12 +748,18 @@ def optimize_bsr_dense_addmm( ).to_sparse_bsr((bm, bk)) dense = make_tensor(k, n, dtype=dtype, device=device) input = make_tensor(m, n, dtype=dtype, device=device) + left_alpha = make_tensor(m, dtype=dtype, device=device) if use_left_alpha else None + right_alpha = ( + make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None + ) tune_bsr_dense_addmm( input, bsr, dense, beta=beta, alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, store=True, force=force, verbose=verbose, @@ -866,9 +882,7 @@ def test_func(): else: raise NotImplementedError(op) - ms_min = triton.testing.do_bench( - test_func, warmup=500, rep=100, fast_flush=False - ) + ms_min = triton.testing.do_bench(test_func, warmup=500, rep=100) return ms_min diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 0017a10e6771bb..5e8a632633e09a 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -295,7 +295,7 @@ def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: else: return dense_input - def to_dense(self): + def to_dense(self): # type:ignore[override] col = self.shape[-1] return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @@ -420,7 +420,7 @@ def from_dense( requires_grad=original_tensor.requires_grad, ) - def to_dense(self): + def to_dense(self): # type: ignore[override] assert self.meta is not None and self.packed is not None return ( sparse_semi_structured_to_dense_cutlass( diff --git a/torch/storage.py b/torch/storage.py index b6ba608c16e5ca..8848649905f936 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,6 +39,8 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -649,6 +651,8 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False + # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) + _fake_device: _Optional[torch.device] = None dtype: torch.dtype diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 8527084f4afa82..2d18da71ec2bda 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -1,7 +1,10 @@ # mypy: ignore-errors +import collections + import torch from torch.testing._internal.common_utils import TEST_WITH_ROCM +from torch.testing._internal.common_utils import TestCase class AutocastTestLists: @@ -234,6 +237,7 @@ def __init__(self, dev): torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn), ] + class AutocastCPUTestLists: # Supplies ops and arguments for test_autocast_* in test/test_cpu.py def __init__(self, dev): @@ -316,6 +320,7 @@ def __init__(self, dev): torch.randn((n, n, n), device=dev, dtype=torch.float32), torch.randn((n, n, n), device=dev, dtype=torch.float32))), ("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32), + ("_addmm_activation", mat1_fp32 + mat2_fp32 + mat3_fp32, {"beta": 1, "alpha": 1, "use_gelu": True}), ("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32), torch.randn((n, n, n), device=dev, dtype=torch.float32))), ("conv_tbc", (torch.randn((10, 7, 3), device=dev, dtype=torch.float32), @@ -368,3 +373,103 @@ def __init__(self, dev): ("cat", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), ("stack", (pointwise0_bf16 + pointwise1_fp32,), (pointwise0_fp16 + pointwise1_fp32,)), ] + + +class TestAutocast(TestCase): + def args_maybe_kwargs(self, op_with_args): + if len(op_with_args) == 2: + return op_with_args[0], op_with_args[1], {} + else: + return op_with_args[0], op_with_args[1], op_with_args[2] + + def _run_autocast_outofplace( + self, + op, + args, + run_as_type, + device, + out_type=None, + module=torch, + add_kwargs=None, + amp_dtype=torch.bfloat16, + ): + # helper to cast args + def cast(val, to_type): + if isinstance(val, torch.Tensor): + return val.to(to_type) if val.is_floating_point() else val + elif isinstance(val, collections.abc.Iterable): + return type(val)(cast(v, to_type) for v in val) + else: + return val + + if add_kwargs is None: + add_kwargs = {} + + self.assertFalse(torch.is_autocast_enabled(device_type=device)) + with torch.amp.autocast(device_type=device, dtype=amp_dtype): + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + + out_type = out_type if out_type is not None else run_as_type + output = output_method = None + + # Try module.* variant, if requested: + if module is not None and hasattr(module, op): + output = getattr(module, op)(*args, **add_kwargs) + if isinstance(output, torch.Tensor): + self.assertTrue( + out_type == output.dtype, + f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}", + ) + # Try Tensor.* variant: + if hasattr(torch.Tensor, op): + output_method = getattr(args[0], op)(*args[1:], **add_kwargs) + if isinstance(output_method, torch.Tensor): + self.assertTrue( + out_type == output_method.dtype, + f"autocast for torch.{op} produced {output_method.dtype}, should produce torch.{out_type}", + ) + + self.assertTrue( + (output is not None) or (output_method is not None), + f"{op} not found as an attribute on either Tensor or the requested module {module}", + ) + + # Accounts for ops that return Tensors, iterables, and other non-Tensors. + # For example, lstm_cell returns a tuple and equal returns bool. + def compare(first, second): + if isinstance(first, torch.Tensor): + return torch.equal(first, second) + elif isinstance(first, collections.abc.Iterable): + return all(compare(f, s) for f, s in zip(first, second)) + else: + return first == second + + # If both torch.* and Tensor.* variants were found, check outputs are identical + if (output is not None) and (output_method is not None): + self.assertTrue(type(output) == type(output_method)) + comparison = compare(output, output_method) + self.assertTrue( + comparison, f"torch.{op} result did not match Tensor.{op} result" + ) + + # Compare numerics to Python-side "autocasting" that (we expect) does the same thing + # as the C++-side autocasting, and should be bitwise accurate. + output_to_compare = output if output is not None else output_method + with torch.amp.autocast(device_type=device, enabled=False): + self.assertFalse( + torch.is_autocast_enabled(device_type=device) + ) + + if module is not None and hasattr(module, op): + control = getattr(module, op)( + *cast(args, run_as_type), **add_kwargs + ) + else: + control = getattr(args[0].to(run_as_type), op)( + *cast(args[1:], run_as_type), **add_kwargs + ) + self.assertTrue(type(output_to_compare) == type(control)) + comparison = compare(output_to_compare, control) + self.assertTrue(comparison, f"torch.{op} result did not match control") + self.assertTrue(torch.is_autocast_enabled(device_type=device)) + self.assertFalse(torch.is_autocast_enabled(device_type=device)) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 709ce32b3ec814..d762156e61a91a 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -750,12 +750,16 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo # Replace your privateuse1 backend name with 'privateuse1' if is_privateuse1_backend_available(): privateuse1_backend_name = torch._C._get_privateuse1_backend_name() - except_for = [ - "privateuse1" if x == privateuse1_backend_name else x for x in except_for - ] - only_for = [ - "privateuse1" if x == privateuse1_backend_name else x for x in only_for - ] + except_for = ( + ["privateuse1" if x == privateuse1_backend_name else x for x in except_for] + if except_for is not None + else None + ) + only_for = ( + ["privateuse1" if x == privateuse1_backend_name else x for x in only_for] + if only_for is not None + else None + ) if except_for: device_type_test_bases = filter( @@ -1642,6 +1646,10 @@ def expectedFailureMeta(fn): return skipIfTorchDynamo()(expectedFailure("meta")(fn)) +def expectedFailureMPS(fn): + return expectedFailure("mps")(fn) + + def expectedFailureXLA(fn): return expectedFailure("xla")(fn) @@ -1883,24 +1891,6 @@ def skipMPS(fn): return skipMPSIf(True, "test doesn't work on MPS backend")(fn) -def skipMPSVersionIfLessThan(major: int, minor: int): - def dec_fn(fn): - @wraps(fn) - def wrap_fn(self, *args, **kwargs): - if self.device_type == "mps": - if not torch.backends.mps.is_macos_or_newer(major, minor): - reason = ( - f"MPS test is skipped for MacOS versions < {major}.{minor} " - ) - raise unittest.SkipTest(reason) - - return fn(self, *args, **kwargs) - - return wrap_fn - - return dec_fn - - def skipHPU(fn): return skipHPUIf(True, "test doesn't work on HPU backend")(fn) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 26bdcce6103120..9ec38c9ca671c2 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -340,9 +340,9 @@ def requires_mpi(): ) -def skip_if_rocm(func): +def skip_if_rocm_multiprocess(func): """Skips a test for ROCm""" - func.skip_if_rocm = True + func.skip_if_rocm_multiprocess = True @wraps(func) def wrapper(*args, **kwargs): @@ -561,8 +561,14 @@ def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> if methodName != "runTest": method_name = methodName super().__init__(method_name) - fn = getattr(self, method_name) - setattr(self, method_name, self.join_or_run(fn)) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != 'runTest': + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e def setUp(self) -> None: super().setUp() @@ -1014,8 +1020,14 @@ def __init__(self, method_name: str = "runTest", methodName: str = "runTest") -> if methodName != "runTest": method_name = methodName super().__init__(method_name) - fn = getattr(self, method_name) - setattr(self, method_name, self.join_or_run(fn)) + try: + fn = getattr(self, method_name) + setattr(self, method_name, self.join_or_run(fn)) + except AttributeError as e: + if methodName != 'runTest': + # we allow instantiation with no explicit method name + # but not an *incorrect* or missing method name + raise ValueError(f"no such test method in {self.__class__}: {methodName}") from e def perThreadSetUp(self): # super().setUp() # TestCase.setUp() calls torch.manual_seed() diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index fe02eeeabb1baf..f9eff69767931c 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -34,7 +34,6 @@ FSDPParamGroup, RegisterPostBackwardFunction, ) -from torch.distributed._tensor import distribute_tensor, DTensor, Shard from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import TrainingState @@ -46,6 +45,7 @@ ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap +from torch.distributed.tensor import distribute_tensor, DTensor, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 7081637d4a3e81..5f23ec4475544f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -13,6 +13,7 @@ import torch import numpy as np +import numpy.typing as npt from torch import inf, nan from typing import Any, Dict, List, Tuple, Union, Sequence @@ -20,7 +21,7 @@ from torch.testing._internal.common_dtype import ( _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, - all_types, empty_types, complex_types_and, integral_types, custom_types, + empty_types, complex_types_and, integral_types, custom_types, ) from torch.testing._internal.common_device_type import \ (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -28,7 +29,7 @@ skipCPUIfNoMklSparse, toleranceOverride, tol) from torch.testing._internal.common_cuda import ( - PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, _get_torch_rocm_version, ) @@ -1573,7 +1574,7 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs): ((S, S), (0, 1), False), ) # Test large inputs to check numerical stability - lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64) else (None,) + lows = (None, 1e3, 1e6) if dtype in (torch.float32, torch.float64, torch.complex64, torch.complex128) else (None,) for low in lows: high = low * 2 if low is not None else None for shape, dim, keepdim in inputs: @@ -2849,28 +2850,29 @@ def error_inputs_multinomial(op_info, device, **kwargs): rep_arg = (False, True) if torch.device(device).type == 'cpu' else (False,) - for rep in rep_arg: - kwargs = {'num_samples': 2, 'replacement': rep} + if torch.device(device).type == 'cpu': + for rep in rep_arg: + kwargs = {'num_samples': 2, 'replacement': rep} - for shape in inputs: - # error case when input tensor contains `inf`, `nan` or negative element - yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), - error_regex=err_msg1 if rep is False else err_msg2) + for shape in inputs: + # error case when input tensor contains `inf`, `nan` or negative element + yield ErrorInput(SampleInput(torch.tensor(shape), kwargs=kwargs), + error_regex=err_msg1 if rep is False else err_msg2) - # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input - x = torch.zeros(3, device=device) - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 1-D input + x = torch.zeros(3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) - # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input - x = torch.zeros(3, 3, device=device) - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) + # error case for the invalid multinomial distribution (sum of probabilities <= 0), 2-D input + x = torch.zeros(3, 3, device=device) + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) - # error case for the invalid multinomial distribution - x[1, :] = 1 - yield ErrorInput(SampleInput(x, kwargs=kwargs), - error_regex=err_msg2) + # error case for the invalid multinomial distribution + x[1, :] = 1 + yield ErrorInput(SampleInput(x, kwargs=kwargs), + error_regex=err_msg2) def error_inputs_gradient(op_info, device, **kwargs): for dtype in [torch.long, torch.float32, torch.complex64]: @@ -4540,7 +4542,7 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar ) def sample_inputs_rms_norm(opinfo, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, high=1000) # Ordered as input shape, normalized_shape and a kwarg dict for eps cases: Tuple[Tuple[int], Tuple[int], dict] = ( # type: ignore[assignment] @@ -8775,8 +8777,13 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] samples = [] + gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support + if TEST_WITH_ROCM and dtype == torch.float32: + causal_options = [False] # FIXME: Large errors with causal+fp32 + else: + causal_options = [True, False] for qkv_shape, is_causal, dropout_p, enable_gqa in product( - qkv_shapes, [True, False], [0.0, 0.5], [True, False]): + qkv_shapes, causal_options, [0.0, 0.5], gqa_options): shape_q, shape_kv = qkv_shape samples.append(SampleInput( make(shape_q), @@ -8806,14 +8813,15 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ dropout_p=0.0) ) - samples.append( - SampleInput( - make((batch, num_heads_q_gqa, seq_q, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - enable_gqa=True + if not TEST_WITH_ROCM: + samples.append( + SampleInput( + make((batch, num_heads_q_gqa, seq_q, head_dim)), + make((batch, num_heads_kv_gqa, seq_kv, head_dim)), + make((batch, num_heads_kv_gqa, seq_kv, head_dim)), + enable_gqa=True + ) ) - ) yield from samples @@ -9337,7 +9345,16 @@ def _set_rightmost_arg_types( if rightmost_supports_tensor: self._rightmost_arg_types.append(ForeachRightmostArgType.Tensor) - def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs): + def _sample_rightmost_arg( + self, + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ): if rightmost_arg_type == ForeachRightmostArgType.TensorList: return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] if rightmost_arg_type == ForeachRightmostArgType.Tensor: @@ -9357,21 +9374,25 @@ def sample_float(): high = 2 if should_use_simpler_scalars else 9 if rightmost_arg_type == ForeachRightmostArgType.ScalarList: - return [ - [random.randint(0, high) + 1 for _ in range(num_tensors)], - [sample_float() for _ in range(num_tensors)], - [complex(sample_float(), sample_float()) for _ in range(num_tensors)], - [True for _ in range(num_tensors)], - [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)], - [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)], - ] + scalarlist_list = [] + scalarlist_list.append([random.randint(0, high) + 1 for _ in range(num_tensors)]) + + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalarlist_list.append([sample_float() for _ in range(num_tensors)]) + if allow_higher_dtype_scalars or dtype.is_complex: + scalarlist_list.append([complex(sample_float(), sample_float()) for _ in range(num_tensors)]) + scalarlist_list.append([1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)]) + scalarlist_list.append([True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)]) + return scalarlist_list if rightmost_arg_type == ForeachRightmostArgType.Scalar: - return ( - random.randint(1, high + 1), - sample_float(), - True, - complex(sample_float(), sample_float()), - ) + scalars = [] + scalars.append(random.randint(1, high + 1)) + if allow_higher_dtype_scalars or dtype.is_floating_point: + scalars.append(sample_float()) + if allow_higher_dtype_scalars or dtype.is_complex: + scalars.append(complex(sample_float(), sample_float())) + scalars.append(True) + return scalars raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype): @@ -9451,6 +9472,7 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * assert "num_input_tensors" not in kwargs _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for rightmost_arg_type in self._rightmost_arg_types: zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) zero_size_foreach_inputs_kwargs["zero_size"] = True @@ -9462,8 +9484,14 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * ] args.append( self._sample_rightmost_arg( - opinfo, ForeachRightmostArgType.TensorList, device, dtype, NUM_SIZE0_TENSORS, - **zero_size_foreach_inputs_kwargs)[0]) + opinfo, + ForeachRightmostArgType.TensorList, + device, + dtype, + NUM_SIZE0_TENSORS, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **zero_size_foreach_inputs_kwargs, + )[0]) kwargs = self._sample_kwargs( opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype) else: @@ -9482,6 +9510,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad _foreach_inputs_kwargs["zero_size"] = False + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) # add empty tensor interspersion to test fully fixing #100701 for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product( @@ -9500,7 +9529,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for _ in range(self.arity - 2) ] rightmost_arg_list = self._sample_rightmost_arg( - opinfo, rightmost_arg_type, device, dtype, num_tensors, + opinfo, rightmost_arg_type, device, dtype, num_tensors, allow_higher_dtype_scalars, **_foreach_inputs_kwargs) for rightmost_arg in rightmost_arg_list: args.append(rightmost_arg) @@ -9554,6 +9583,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): assert isinstance(num_input_tensors, list) _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + _allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for num_tensors, ord, out_dtype in product( num_input_tensors, @@ -9584,24 +9614,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) -class foreach_lerp_sample_func(foreach_inputs_sample_func): - def _sample_rightmost_arg(self, opinfo, rightmost_arg_type, device, dtype, num_tensors, **_foreach_inputs_kwargs): - if rightmost_arg_type == ForeachRightmostArgType.TensorList: - return [sample_inputs_foreach(None, device, dtype, num_tensors, **_foreach_inputs_kwargs)] - if rightmost_arg_type == ForeachRightmostArgType.ScalarList: - return [ - [random.randint(0, 9) + 1 for _ in range(num_tensors)], - [1.0 - random.random() for _ in range(num_tensors)], - [complex(1.0 - random.random(), 1.0 - random.random()) for _ in range(num_tensors)], - [True for _ in range(num_tensors)], - [1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 3)], - [True, 1, 2.0, 3.0 + 4.5j] + [3.0 for _ in range(num_tensors - 4)], - ] - if rightmost_arg_type == ForeachRightmostArgType.Scalar: - return [random.random()] - raise AssertionError(f"Invalid rightmost_arg_type of {rightmost_arg_type}") - - class foreach_pointwise_sample_func(foreach_inputs_sample_func): def __init__( @@ -9636,6 +9648,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): assert isinstance(num_input_tensors, list) _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad + allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) for num_tensors, rightmost_arg_type in itertools.product(num_input_tensors, self._rightmost_arg_types): input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) @@ -9644,7 +9657,15 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for _ in range(2 - int(rightmost_arg_type == ForeachRightmostArgType.TensorList)) ] rightmost_arg_list = self._sample_rightmost_arg( - opinfo, rightmost_arg_type, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) + opinfo, + rightmost_arg_type, + device, + dtype, + num_tensors, + zero_size=False, + allow_higher_dtype_scalars=allow_higher_dtype_scalars, + **_foreach_inputs_kwargs, + ) for rightmost_arg in rightmost_arg_list: kwargs = {} if rightmost_arg_type == ForeachRightmostArgType.TensorList: @@ -10113,14 +10134,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_inplace", dtypes=integral_types_and(torch.bool), ), - # FIXME: fails check - # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.cpp#L508-L510 - DecorateInfo( - unittest.skip("Skipped!"), - "TestForeach", - "test_parity", - dtypes=(torch.bool,), - ), ), ), ForeachFuncInfo( @@ -10166,14 +10179,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10226,14 +10231,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10286,14 +10283,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10374,14 +10363,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10461,14 +10442,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10521,14 +10494,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=integral_types_and(torch.bool) + complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10580,75 +10545,18 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_dispatch_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_inplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_outplace", - device_type="cuda", - dtypes=complex_types(), - ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_autodiff", - device_type="cuda", - dtypes=(torch.complex128,), - ), ), ), ForeachFuncInfo( @@ -10694,14 +10602,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10827,14 +10727,6 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10863,86 +10755,36 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): "test_dispatch_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_dispatch_symbolic_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=complex_types() + integral_types_and(torch.bool), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_inplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), DecorateInfo( unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=complex_types(), ), - DecorateInfo( - unittest.expectedFailure, - "TestMeta", - "test_meta_outplace", - device_type="cuda", - dtypes=(torch.bfloat16,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types() + (torch.bfloat16,), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -10965,14 +10807,15 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # These tests fail with aten._local_scalar_dense not being implemented. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16)), # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool, torch.bfloat16, torch.float16, torch.float64)), ), ), ForeachFuncInfo( @@ -11002,13 +10845,12 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), ), ), ForeachFuncInfo( @@ -11020,22 +10862,13 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - # fails with div_cpu is not implemented with ComplexHalf - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=(torch.float16,), device_type='cpu'), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.float16,), device_type='cpu'), + dtypes=integral_types_and(torch.bool)), ), ), ForeachFuncInfo( @@ -11045,22 +10878,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11083,22 +10916,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11122,22 +10955,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=False, supports_forward_ad=False, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11161,22 +10994,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_forward_ad=False, supports_inplace_autograd=False, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - DecorateInfo( - unittest.expectedFailure, - "TestForeach", - "test_parity", - device_type="cuda", - dtypes=complex_types(), - active_if=lambda kwargs: not kwargs.get("noncontiguous", False), - ), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", + dtypes=complex_types_and(torch.bool)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", + dtypes=complex_types_and(torch.bool)), DecorateInfo( unittest.expectedFailure, "TestForeach", @@ -11201,30 +11034,19 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", dtypes=(torch.bool,)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.bool,)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), + dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.half,), device_type="cpu"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.half,), device_type="cpu"), + dtypes=(torch.bool,),), DecorateInfo(unittest.skip("flaky"), "TestForeach", "test_parity", device_type="cpu", dtypes=(torch.complex64,)), DecorateInfo( unittest.expectedFailure, @@ -11255,23 +11077,19 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_inplace_autograd=True, supports_forward_ad=True, decorators=( - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - # Samples have complex types and inplace only works if the dtype is complex. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + # # Samples have complex types and inplace only works if the dtype is complex. + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=(torch.bool,)), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), ), ), ForeachFuncInfo( @@ -11283,22 +11101,22 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): decorators=( # Samples have complex types and inplace only works if the dtype is complex. DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), # fails with div_cpu is not implemented with ComplexHalf DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types_and(torch.bool)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16)), + dtypes=integral_types() + complex_types_and(torch.bool)), ), ), ] @@ -11387,7 +11205,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): foreach_other_op_db: List[ForeachFuncInfo] = [ ForeachFuncInfo( "lerp", - sample_inputs_func=foreach_lerp_sample_func(3, True, False), + sample_inputs_func=foreach_inputs_sample_func(3, True, False), supports_autograd=True, supports_inplace_autograd=True, supports_forward_ad=True, @@ -11525,7 +11343,7 @@ def _tanh_gelu_ref(X): return _gelu_ref(X) -def reference_one_hot(a: np.ndarray, num_classes: int = -1) -> np.ndarray: +def reference_one_hot(a: npt.NDArray, num_classes: int = -1) -> npt.NDArray: if num_classes == -1: num_classes = int(np.amax(a) + 1) @@ -11545,11 +11363,11 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): +def reference_layer_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] -def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight, bias, eps): +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight, bias, eps): feature_size = np.prod(normalized_shape) inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] mean = inp_view.mean(axis=-1, keepdims=True) @@ -11566,7 +11384,7 @@ def reference_native_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], w return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) -def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, eps=None): +def reference_rms_norm(inp: npt.NDArray, normalized_shape: Tuple[int], weight=None, eps=None): if eps is None: eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps feature_size = np.prod(normalized_shape) @@ -11578,7 +11396,7 @@ def reference_rms_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=Non return Y.reshape(*inp.shape) -def reference_group_norm(inp: np.ndarray, num_groups: int, weight=None, bias=None, eps=1e-5): +def reference_group_norm(inp: npt.NDArray, num_groups: int, weight=None, bias=None, eps=1e-5): inp_view = inp if np.prod(inp.shape) != 0: inp_view = inp.reshape((inp.shape[0], num_groups, -1)) @@ -11664,7 +11482,7 @@ def reference_std_var(f): g = reference_reduction_numpy(f) @wraps(g) - def wrapper(x: np.ndarray, *args, **kwargs): + def wrapper(x: npt.NDArray, *args, **kwargs): assert not ('unbiased' in kwargs and 'correction' in kwargs) if 'unbiased' in kwargs: @@ -13762,8 +13580,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_gradient, error_inputs_func=error_inputs_gradient), OpInfo('isin', - dtypes=all_types(), - dtypesIfCUDA=all_types_and(torch.half), + dtypes=all_types_and(torch.bfloat16, torch.half), supports_autograd=False, sample_inputs_func=sample_inputs_isin), OpInfo('kthvalue', @@ -14893,7 +14710,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_adaptive_avg_pool2d), OpInfo('nn.functional.adaptive_avg_pool3d', dtypes=floating_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), decorators=( # RuntimeError: # adaptive_avg_pool3d(Tensor input, int[3] output_size) -> (Tensor): @@ -14947,8 +14763,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): error_inputs_func=error_inputs_adaptive_max_pool2d, sample_inputs_func=sample_inputs_adaptive_max_pool2d), OpInfo('nn.functional.adaptive_max_pool3d', - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16), + dtypes=floating_types_and(torch.bfloat16, torch.half), decorators=( # RuntimeError: # adaptive_max_pool3d(Tensor input, int[3] output_size) -> (Tensor): @@ -15272,7 +15087,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), # TF32 DecorateInfo( - toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3)}), + toleranceOverride({torch.float32: tol(atol=5e-3, rtol=1e-3), + torch.complex64: tol(atol=5e-3, rtol=1e-3)}), 'TestCommon', 'test_noncontiguous_samples', ), DecorateInfo( @@ -15402,8 +15218,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='reflect', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bfloat16, torch.half), sample_inputs_func=partial(sample_inputs_nn_pad, mode='reflect'), skips=( # Doesn't have a corresponding aten operator. @@ -15417,8 +15232,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=partial(sample_inputs_nn_pad, mode='replicate'), skips=( # Doesn't have a corresponding aten operator. @@ -15432,8 +15246,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): variant_test_name='replicate_negative', supports_forward_ad=True, supports_fwgrad_bwgrad=True, - dtypes=all_types_and_complex_and(torch.bfloat16), - dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_nn_pad_replicate_negative, skips=( # Doesn't have a corresponding aten operator. @@ -15477,8 +15290,8 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): autodiff_nonfusible_nodes=["aten::hardswish"]), OpInfo('nn.functional.unfold', aten_name='im2col', - dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_nn_unfold, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, @@ -16261,12 +16074,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # skip for sm < 80 DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace', - device_type='cuda', dtypes=(torch.bfloat16,), active_if=not SM80OrLater), # FIXME DecorateInfo(unittest.skip('test_cow_input does not work with efficient attention on ROCM'), 'TestCompositeCompliance', 'test_cow_input', @@ -16280,24 +16087,17 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): 'TestFakeTensor', 'test_fake_crossref_backward_no_amp', device_type='cuda', dtypes=(torch.bfloat16, torch.float16, torch.float32), active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), - # registered in fake_impls.py instead of _meta_registrations.py, so meta kernels will fail. - # However, for implementations that fall back to the constituent ops, the meta kernels may not - # fail. Fused kernels will fail, whereas unfused kernels will not fail. - # All fused kernels support bf16 and fp16 - so if fused attention is supported, the test will fail. - # mem_eff_attention also supports fp32 - so if it is supported the test will fail. - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), - # TODO: float32 support in ROCM efficient attention - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.bfloat16, torch.float16), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION), + # for element 1, was torch.Size([4, 4, 0]) but real shape was torch.Size([16, 3, 0]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.bfloat16, torch.float16,), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides", - dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_MEM_EFF_ATTENTION),), + device_type="cuda", dtypes=[torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),), ), OpInfo( 'torch.ops.aten._flash_attention_forward', @@ -16313,14 +16113,14 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False, decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], skips=( - # Device mismatch due to philox seed and offset - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'), - # meta implementation is in fake_impls.py instead of being a meta registration - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + # Checking the scalar value of the philox seed and offset # Checking the scalar value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), @@ -16348,14 +16148,14 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skips=( - # Device mismatch due to philox seed and offset - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake', device_type='cuda'), - # meta implementation is in fake_impls.py instead of being a meta registration - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace"), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace"), + # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16, torch.float32], + active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), + DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", + dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), # Checking the scaler value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), @@ -18520,7 +18320,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('unique', dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), - dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.uint16, torch.uint32, torch.uint64), sample_inputs_func=sample_inputs_unique, supports_out=False, supports_autograd=False, @@ -18532,7 +18331,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('unique_consecutive', dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), - dtypesIfCUDA=all_types_and(torch.bool, torch.float16), sample_inputs_func=sample_inputs_unique_consecutive, supports_out=False, supports_autograd=False, @@ -19795,7 +19593,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): sample_inputs_func=sample_inputs_zero_), OpInfo('logsumexp', aliases=('special.logsumexp',), - dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -19822,6 +19620,24 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # vmap does not support inplace views check_inplace_batched_forward_grad=False, sample_inputs_func=sample_inputs_transpose_swapdims), + OpInfo('transpose_copy', + assert_jit_shape_analysis=True, + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + # vmap does not support inplace views + check_inplace_batched_forward_grad=False, + sample_inputs_func=sample_inputs_transpose_swapdims, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), + DecorateInfo( + unittest.expectedFailure, + 'TestJit', + 'test_variant_consistency_jit', + dtypes=(torch.float32,) + ), + )), OpInfo('T', op=lambda x: x.T, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), @@ -20937,7 +20753,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): # FIXME: mean needs 'dim' parameter when using the 'out' overload. # Adding it with 'generate_args_kwargs' does not work, since these also get passed # onto the reference implementations. - supports_out=False, + supports_out=True, assert_autodiffed=True, assert_jit_shape_analysis=True, promotes_int_to_float=True, @@ -20945,6 +20761,9 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ref=reference_reduction_numpy(np.mean), error_inputs_func=error_inputs_mean, skips=( + # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result + # of dtype torch.float32 into an out= with dtype torch.long + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', device_type='cuda', dtypes=[torch.float32]), # FIXME: mean does not support passing keepdim without passing dim DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), # FIXME: mean reduces all dimensions when dim=[] @@ -23995,8 +23814,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.tensor_split", torch_opinfo_name="tensor_split", skips=( - # TensorMeta doesn't support tolist - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), # RuntimeError: no _refs support for torch.Tensor.tolist DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), ), @@ -24035,6 +23852,15 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.transpose", torch_opinfo_name="transpose", ), + PythonRefInfo( + "_refs.transpose_copy", + torch_opinfo_name="transpose_copy", + skips=( + # RuntimeError: no _refs support for torch.Tensor.is_conj + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), + ), + supports_out=True, + ), PythonRefInfo( "_refs.t", torch_opinfo_name="t", diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index bdd0375850b438..63963bab1b050a 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -15,8 +15,8 @@ from torch.testing._internal.common_dtype import ( floating_types, floating_and_complex_types_and, get_all_fp_dtypes) from torch.testing._internal.common_device_type import ( - _TestParametrizer, _update_param_kwargs, toleranceOverride, tol, - skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipMPSVersionIfLessThan, + _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol, + skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipCUDAVersionIn) from torch.testing._internal.common_methods_invocations import DecorateInfo from torch.testing._internal.common_nn import ( @@ -1718,14 +1718,6 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 2, 3))), desc='2d_affine'), - ModuleInput( - constructor_input=FunctionInput(3, 6, 1e-3), - forward_input=FunctionInput(make_input((4, 6, 28, 28))), - desc='2d_affine_large_feature'), - ModuleInput( - constructor_input=FunctionInput(3, 51, 1e-5, False), - forward_input=FunctionInput(make_input((2, 51, 28, 28))), - desc='2d_no_affine_large_feature'), ModuleInput( constructor_input=FunctionInput(3, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), @@ -1945,7 +1937,9 @@ def rms_norm_reference_fn(m, p, i): normalized_shape = m.normalized_shape weight = m.weight dims = [ndim - i - 1 for i in range(len(normalized_shape))] - result = i * torch.rsqrt(i.pow(2).mean(dim=dims, keepdim=True) + m.eps) + upcasted_i = i.float() + result = upcasted_i * torch.rsqrt(upcasted_i.pow(2).mean(dim=dims, keepdim=True) + m.eps) + result = result.type_as(i) if weight is not None: result *= weight return result @@ -3361,6 +3355,9 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ] +_macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(15, 0) + + # Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ ModuleInfo(torch.nn.AdaptiveAvgPool1d, @@ -3438,9 +3435,6 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm1d, skips=( - # test fails on MPS backend and is being investigated. - # See https://github.com/pytorch/pytorch/issues/100914 - DecorateInfo(skipMPS), # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt DecorateInfo( @@ -3459,9 +3453,8 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm2d, skips=( - # test fails on MPS backend and is being investigated. - # See https://github.com/pytorch/pytorch/issues/100914 - DecorateInfo(skipMPS), + # See https://github.com/pytorch/pytorch/issues/134580 + DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')), # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt DecorateInfo( @@ -4069,11 +4062,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Hardswish, module_inputs_func=module_inputs_torch_nn_Hardswish, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),), supports_gradgrad=False), ModuleInfo(torch.nn.Hardtanh, @@ -4188,11 +4185,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.ReLU, module_inputs_func=module_inputs_torch_nn_ReLU, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.LeakyReLU, @@ -4226,11 +4227,15 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Sigmoid, module_inputs_func=module_inputs_torch_nn_Sigmoid, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.LogSigmoid, @@ -4285,20 +4290,28 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Tanh, module_inputs_func=module_inputs_torch_nn_Tanh, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.Tanhshrink, module_inputs_func=module_inputs_torch_nn_Tanhshrink, - skips=( + skips=None if _macos15_or_newer else ( # Fails on backward check on MPS # See https://github.com/pytorch/pytorch/issues/107214 DecorateInfo( - skipMPSVersionIfLessThan(15, 0) + unittest.expectedFailure, + 'TestModule', + 'test_memory_format', + active_if=operator.itemgetter('training'), + device_type='mps', ),) ), ModuleInfo(torch.nn.Threshold, diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 7d4d225c8e4012..bd8f1f2963f525 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -132,6 +132,8 @@ def __init__( ), # the optim supports passing in sparse gradients as well as dense grads supports_sparse: bool = False, + # the optimizer constructor supports passing in capturable as a kwarg + has_capturable_arg: bool = False, # the optim only supports one config: sparse grads w/ dense params, see SparseAdam only_supports_sparse_grads: bool = False, # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests, @@ -157,6 +159,7 @@ def __init__( self.supported_impls = supported_impls self.not_og_supported_flags = not_og_supported_flags self.supports_sparse = supports_sparse + self.has_capturable_arg = has_capturable_arg self.metadata_for_sparse = metadata_for_sparse self.only_supports_sparse_grads = only_supports_sparse_grads self.supports_complex = supports_complex @@ -330,10 +333,11 @@ def optim_inputs_func_adadelta(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", + desc="maximize, weight_decay", ), OptimizerInput( params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho" @@ -631,9 +635,14 @@ def optim_inputs_func_adamax(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, + kwargs={"maximize": True}, desc="maximize", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize, weight_decay", + ), ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) @@ -788,9 +797,16 @@ def optim_inputs_func_nadam(device, dtype=None): ), OptimizerInput( params=None, - kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + kwargs={ + "weight_decay": 0.1, + }, desc="weight_decay", ), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3}, + desc="weight_decay, momentum_decay", + ), OptimizerInput( params=None, kwargs={ @@ -933,11 +949,26 @@ def optim_inputs_func_rmsprop(device, dtype=None): OptimizerInput( params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay" ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + }, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True}, desc="centered", ), + OptimizerInput( + params=None, + kwargs={ + "maximize": True, + "weight_decay": 0.1, + }, + desc="maximize, weight_decay", + ), OptimizerInput( params=None, kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1}, @@ -951,7 +982,7 @@ def optim_inputs_func_rmsprop(device, dtype=None): "momentum": 0.1, "maximize": True, }, - desc="maximize", + desc="maximize, centered, weight_decay, w/ momentum", ), ] + (cuda_supported_configs if _get_device_type(device) == "cuda" else []) @@ -1022,7 +1053,15 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr" ), + OptimizerInput( + params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay" + ), OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"), + OptimizerInput( + params=None, + kwargs={"weight_decay": 0.1, "maximize": True}, + desc="maximize", + ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "dampening": 0.5}, @@ -1031,18 +1070,13 @@ def optim_inputs_func_sgd(device, dtype=None): OptimizerInput( params=None, kwargs={"momentum": 0.9, "weight_decay": 0.1}, - desc="non-zero weight_decay", + desc="weight_decay w/ momentum", ), OptimizerInput( params=None, kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1}, desc="nesterov", ), - OptimizerInput( - params=None, - kwargs={"weight_decay": 0.1, "maximize": True}, - desc="maximize", - ), ] @@ -1208,6 +1242,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adadelta, optim_error_inputs_func=optim_error_inputs_func_adadelta, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1493,6 +1528,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( ), optim_error_inputs_func=optim_error_inputs_func_adam, supported_impls=("foreach", "differentiable", "fused"), + has_capturable_arg=True, not_og_supported_flags=( "foreach", "differentiable", @@ -1578,6 +1614,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_adamax, optim_error_inputs_func=optim_error_inputs_func_adamax, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1630,6 +1667,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "capturable", ), supports_fused_on=("cpu", "cuda", "mps"), + has_capturable_arg=True, decorators=( # Expected error between compiled forloop and fused optimizers DecorateInfo( @@ -1710,6 +1748,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_asgd, optim_error_inputs_func=optim_error_inputs_func_asgd, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1822,6 +1861,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_nadam, optim_error_inputs_func=optim_error_inputs_func_nadam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1870,6 +1910,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_radam, optim_error_inputs_func=optim_error_inputs_func_radam, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"), @@ -1915,6 +1956,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rmsprop, optim_error_inputs_func=optim_error_inputs_func_rmsprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115 @@ -1964,6 +2006,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( optim_inputs_func=optim_inputs_func_rprop, optim_error_inputs_func=optim_error_inputs_func_rprop, supported_impls=("foreach", "differentiable"), + has_capturable_arg=True, skips=( DecorateInfo( skipIfMps, # Rprop doesn't update for non-contiguous, see #118117 diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 87891bcefdfc1a..ce990cd0aaf8ef 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -1247,6 +1247,7 @@ def _test_quantizer( export_with_dynamic_shape=False, is_qat=False, is_debug_mode=False, + capture_pre_autograd_graph_node_occurrence=None, ): # resetting dynamo cache torch._dynamo.reset() @@ -1305,6 +1306,10 @@ def _test_quantizer( for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): if k in expected_node_occurrence: node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] + if capture_pre_autograd_graph_node_occurrence is not None: + node_occurrence = { + ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items() + } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) self.assertEqual(fx_quant_output, pt2_quant_output) diff --git a/torch/testing/_internal/common_subclass.py b/torch/testing/_internal/common_subclass.py index f6a8ed065cb819..3c76e19fab4eb6 100644 --- a/torch/testing/_internal/common_subclass.py +++ b/torch/testing/_internal/common_subclass.py @@ -3,6 +3,8 @@ import torch from copy import deepcopy from torch.utils._pytree import tree_map +import torch.utils._pytree as pytree + # TODO: Move LoggingTensor here. from torch.testing._internal.logging_tensor import LoggingTensor @@ -216,3 +218,49 @@ def __init__(self, name, create_fn, closed_under_ops=True): closed_under_ops=False # sparse semantics ), } + +class SubclassWithTensorFactory(torch.Tensor): + @staticmethod + def __new__(cls, src): + shape = src.shape + kwargs = {} + kwargs["strides"] = src.stride() + kwargs["storage_offset"] = src.storage_offset() + kwargs["device"] = src.device + kwargs["layout"] = src.layout + kwargs["requires_grad"] = src.requires_grad + kwargs["dtype"] = src.dtype + out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + return out + + def __init__(self, src): + self.src = src + + def __repr__(self): + return f"{self.__class__.__name__}" + + def __tensor_flatten__(self): + return ["src"], None + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride): + src = inner_tensors["src"] + return cls(src) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + + def _fn(x): + return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src + + _args = pytree.tree_map_only(cls, _fn, args) + _kwargs = pytree.tree_map_only(cls, _fn, kwargs) + + _out = func(*_args, **_kwargs) + + _out_flat, _out_spec = pytree.tree_flatten(_out) + + out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat] + return pytree.tree_unflatten(out_flat, _out_spec) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 764772cc767989..ea3485249bb020 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -103,6 +103,10 @@ except ImportError: has_pytest = False + +MI300_ARCH = ("gfx940", "gfx941", "gfx942") + + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -1503,6 +1507,10 @@ def xfailIfTorchDynamo(func): return unittest.expectedFailure(func) if TEST_WITH_TORCHDYNAMO else func +def xfailIfLinux(func): + return unittest.expectedFailure(func) if IS_LINUX and not TEST_WITH_ROCM else func + + def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): """ Usage: @@ -1759,6 +1767,19 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest("test currently only works on the ROCm stack") return wrapper +def runOnRocmArch(arch: Tuple[str, ...]): + def dec_fn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if TEST_WITH_ROCM: + prop = torch.cuda.get_device_properties(0) + if prop.gcnArchName.split(":")[0] not in arch: + reason = f"skipIfRocm: test only runs on {arch}" + raise unittest.SkipTest(reason) + return fn(self, *args, **kwargs) + return wrap_fn + return dec_fn + def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"): def dec_fn(fn): reason = f"skipIfXpu: {msg}" diff --git a/torch/testing/_internal/distributed/common_state_dict.py b/torch/testing/_internal/distributed/common_state_dict.py index 5aa8d7f14aa0ab..c2ab9af9f197e8 100644 --- a/torch/testing/_internal/distributed/common_state_dict.py +++ b/torch/testing/_internal/distributed/common_state_dict.py @@ -43,7 +43,12 @@ def _verify_msd( dist_param = dist_msd.get(fqn, None) if not options.ignore_frozen_params: self.assertIsNotNone(dist_param, f"{fqn=}") - self._compare_tensor(param, dist_param, offload_to_cpu) + try: + self._compare_tensor(param, dist_param, offload_to_cpu) + except AssertionError as e: + raise AssertionError( + f"{fqn} has mismatched value {param} {dist_param}" + ) from e elif dist_param is None: self.assertFalse(param.requires_grad, f"{fqn=}") diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index de8b51e3defb98..00a67b20a73c7c 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -18,7 +18,7 @@ requires_gloo, requires_nccl, skip_if_lt_x_gpu, - skip_if_rocm, + skip_if_rocm_multiprocess, ) from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( @@ -662,7 +662,7 @@ class CudaDdpComparisonTest(CommonDdpComparisonTest): @skip_if_lt_x_gpu(NUM_TRAINERS) @requires_nccl() @dist_init - @skip_if_rocm + @skip_if_rocm_multiprocess def test_ddp_dist_autograd_local_vs_remote_gpu(self): # Each trainer uses a different random seed. Otherwise, they are going # to have exactly the same initial model parameters, input, and diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index b20165cd98fac4..981c8e59580649 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -62,7 +62,7 @@ initialize_temp_directories, cleanup_temp_dir, simple_sparse_reduce_tests, - skip_if_rocm, + skip_if_rocm_multiprocess, skip_if_small_worldsize, skip_if_odd_worldsize, skip_if_lt_x_gpu, @@ -3936,7 +3936,7 @@ def test_all_to_all(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_cuda(self): group, group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -3952,7 +3952,7 @@ def test_all_to_all_complex(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_cuda_complex(self): group, group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -4020,7 +4020,7 @@ def test_all_to_all_group(self): BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single" ) @skip_if_small_worldsize - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_group_cuda(self): group, group_id, rank = self._init_group_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -4080,7 +4080,7 @@ def test_all_to_all_full_group(self): @skip_but_pass_in_sandcastle_if( BACKEND != "nccl", "Only NCCL supports CUDA all_to_all" ) - @skip_if_rocm + @skip_if_rocm_multiprocess def test_all_to_all_full_group_cuda(self): group, group_id, rank = self._init_full_group_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 172c9efb8800ba..fa352cb5a3777e 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -61,6 +61,7 @@ def f2(x, y0, y1): "call_torchbind", "triton_kernel_wrapper_mutation", "triton_kernel_wrapper_functional", + "hints_wrapper", ] torch.library.define( diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 7e952e929feefa..00f9ba0dd2ee29 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -1,5 +1,6 @@ # mypy: ignore-errors +import logging import torch import re import unittest @@ -22,6 +23,8 @@ IS_WINDOWS, ) +log: logging.Logger = logging.getLogger(__name__) + def test_cpu(): try: CppCodeCache.load("") @@ -72,6 +75,11 @@ def skipDeviceIf(cond, msg, *, device): if cond: def decorate_fn(fn): def inner(self, *args, **kwargs): + if not hasattr(self, "device"): + warn_msg = "Expect the test class to have attribute device but not found. " + if hasattr(self, "device_type"): + warn_msg += "Consider using the skip device decorators in common_device_type.py" + log.warning(warn_msg) if self.device == device: raise unittest.SkipTest(msg) return fn(self, *args, **kwargs) diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index f5290c7643ce14..eda339ebfe68a6 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -1182,7 +1182,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar ), ReductionOpInfo( "masked.logsumexp", - dtypes=all_types_and(torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), method_variant=None, nan_policy="propagate", supports_out=False, diff --git a/torch/testing/_internal/opinfo/definitions/nested.py b/torch/testing/_internal/opinfo/definitions/nested.py index ea678c2e4f8750..9654037fa70ad1 100644 --- a/torch/testing/_internal/opinfo/definitions/nested.py +++ b/torch/testing/_internal/opinfo/definitions/nested.py @@ -249,6 +249,31 @@ def sample_inputs_masked_select( ) +def sample_inputs_nn_functional_rms_norm( + op_info, device, dtype, requires_grad, **kwargs +): + for njt in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4] + ): + # normalize over non-ragged dims + for start_dim in range(2, njt.dim()): + normalized_shape = njt.shape[start_dim:] + weight = torch.randn( + normalized_shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + yield SampleInput( + njt, + kwargs={ + "normalized_shape": normalized_shape, + "weight": weight, + }, + ) + + sample_inputs_nn_functional_threshold = partial( sample_inputs_elementwise_njt_unary, op_kwargs={"threshold": float.fromhex("0x1.3ap-3"), "value": -9}, @@ -264,6 +289,7 @@ def sample_inputs_masked_select( njt_sample_inputs = { "clone": sample_inputs_clone, **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)}, + "nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm, "nn.functional.threshold": sample_inputs_nn_functional_threshold, **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)}, "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0), diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py index 41973dc2c0518a..05468e10da2c90 100644 --- a/torch/testing/_internal/opinfo/utils.py +++ b/torch/testing/_internal/opinfo/utils.py @@ -6,6 +6,7 @@ from typing import Sequence import numpy as np +import numpy.typing as npt import torch from torch.testing._internal.common_cuda import TEST_CUDA @@ -206,7 +207,7 @@ def reference_reduction_numpy(f, supports_keepdims=True): """ @wraps(f) - def wrapper(x: np.ndarray, *args, **kwargs): + def wrapper(x: npt.NDArray, *args, **kwargs): # Copy keys into a set keys = set(kwargs.keys()) diff --git a/torch/testing/_internal/optests/aot_autograd.py b/torch/testing/_internal/optests/aot_autograd.py index 4f281c7771757c..a5552e23c8a467 100644 --- a/torch/testing/_internal/optests/aot_autograd.py +++ b/torch/testing/_internal/optests/aot_autograd.py @@ -118,7 +118,9 @@ def check(args, ignore_failure=False): raise # See https://github.com/pytorch/pytorch/pull/98960#issuecomment-1505962215 - if all(x is None for x in orig_grad): + tensor_args = [x for x in pytree.tree_flatten(args)[0] if isinstance(x, torch.Tensor)] + any_non_leaves = any(x.grad_fn is not None for x in tensor_args) + if all(x is None for x in orig_grad) and any_non_leaves: with assert_raises_regex_fn(RuntimeError, 'does not require grad and does not have a grad_fn'): call_forwards_backwards(compiled_f, args) return diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 1ac2dfec09d6bd..d3a8065f294047 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -78,6 +78,31 @@ def add_kernel_autotuned( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2), + ], + key=[], + ) + @triton.jit + def add_kernel_autotuned_weird_param_order( + in_ptr0, + in_ptr1, + n_elements, + BLOCK_SIZE: "tl.constexpr", + out_ptr, + ): + # out_ptr is after an autotuned param that's declared as tl.constexpr. + # This param ordering can create bugs if not handled correctly. + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( configs=[ triton.Config( diff --git a/torch/types.py b/torch/types.py index 7725716909196e..536af3eef8c9d1 100644 --- a/torch/types.py +++ b/torch/types.py @@ -1,7 +1,5 @@ # mypy: allow-untyped-defs -import builtins - # In some cases, these basic types are shadowed by corresponding # top-level values. The underscore variants let us refer to these # types. See https://github.com/python/mypy/issues/4146 for why these @@ -14,100 +12,114 @@ int as _int, str as _str, ) -from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union - -import torch -from torch import SymBool, SymFloat, SymInt +from typing import Any, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +# `as` imports have better static analysis support than assignment `ExposedType: TypeAlias = HiddenType` +from torch import ( # noqa: F401 + device as _device, + DispatchKey as DispatchKey, + dtype as _dtype, + layout as _layout, + qscheme as _qscheme, + Size as Size, + SymBool as SymBool, + SymFloat as SymFloat, + SymInt as SymInt, + Tensor as Tensor, +) if TYPE_CHECKING: from torch.autograd.graph import GradientEdge + __all__ = ["Number", "Device", "Storage"] # Convenience aliases for common composite types that we need # to talk about in PyTorch - -_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]] -_TensorOrTensorsOrGradEdge = Union[ - torch.Tensor, - Sequence[torch.Tensor], +_TensorOrTensors: TypeAlias = Union[Tensor, Sequence[Tensor]] # noqa: PYI047 +_TensorOrTensorsOrGradEdge: TypeAlias = Union[ # noqa: PYI047 + Tensor, + Sequence[Tensor], "GradientEdge", Sequence["GradientEdge"], ] -_dtype = torch.dtype -_device = torch.device -_qscheme = torch.qscheme -_layout = torch.layout -_size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]] -_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]] -_dispatchkey = Union[builtins.str, torch._C.DispatchKey] +_size: TypeAlias = Union[Size, List[int], Tuple[int, ...]] # noqa: PYI042,PYI047 +_symsize: TypeAlias = Union[Size, Sequence[Union[int, SymInt]]] # noqa: PYI042,PYI047 +_dispatchkey: TypeAlias = Union[str, DispatchKey] # noqa: PYI042,PYI047 # int or SymInt -IntLikeType = Union[_int, torch.SymInt] +IntLikeType: TypeAlias = Union[int, SymInt] +# float or SymFloat +FloatLikeType: TypeAlias = Union[float, SymFloat] +# bool or SymBool +BoolLikeType: TypeAlias = Union[bool, SymBool] py_sym_types = (SymInt, SymFloat, SymBool) -PySymType = Union[SymInt, SymFloat, SymBool] +PySymType: TypeAlias = Union[SymInt, SymFloat, SymBool] # Meta-type for "numeric" things; matches our docs -Number = Union[builtins.int, builtins.float, builtins.bool] +Number: TypeAlias = Union[int, float, bool] # Meta-type for "device-like" things. Not to be confused with 'device' (a # literal device object). This nomenclature is consistent with PythonArgParser. # None means use the default device (typically CPU) -Device = Optional[Union[_device, builtins.str, builtins.int]] -del Optional - -# Storage protocol implemented by ${Type}StorageBase classes +Device: TypeAlias = Union[_device, str, int, None] +# Storage protocol implemented by ${Type}StorageBase classes class Storage: - _cdata: _int - device: torch.device - dtype: torch.dtype - _torch_load_uninitialized: _bool + _cdata: int + device: _device + dtype: _dtype + _torch_load_uninitialized: bool - def __deepcopy__(self, memo: dict) -> "Storage": + def __deepcopy__(self, memo: Dict[int, Any]) -> "Storage": raise NotImplementedError - def _new_shared(self, size: _int) -> "Storage": + def _new_shared(self, size: int) -> "Storage": raise NotImplementedError def _write_file( self, f: Any, - is_real_file: _bool, - save_size: _bool, - element_size: _int, + is_real_file: bool, + save_size: bool, + element_size: int, ) -> None: raise NotImplementedError - def element_size(self) -> _int: + def element_size(self) -> int: raise NotImplementedError - def is_shared(self) -> _bool: + def is_shared(self) -> bool: raise NotImplementedError def share_memory_(self) -> "Storage": raise NotImplementedError - def nbytes(self) -> _int: + def nbytes(self) -> int: raise NotImplementedError def cpu(self) -> "Storage": raise NotImplementedError - def data_ptr(self) -> _int: + def data_ptr(self) -> int: raise NotImplementedError def from_file( self, - filename: _str, - shared: _bool = False, - nbytes: _int = 0, + filename: str, + shared: bool = False, + nbytes: int = 0, ) -> "Storage": raise NotImplementedError - def _new_with_file(self, f: Any, element_size: _int) -> "Storage": + def _new_with_file( + self, + f: Any, + element_size: int, + ) -> "Storage": raise NotImplementedError diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index eea7f058bfc5e0..70c65a69071733 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -3,7 +3,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque +from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type from typing_extensions import TypeGuard from collections import deque @@ -391,6 +391,11 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: and hasattr(t, "__tensor_unflatten__") ) +def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]: + """Same as above, but takes a type argument instead of an instance.""" + return (issubclass(t, torch.Tensor) and t != torch.Tensor + and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")) + def transform_subclass(t, callback, outer_size=None, outer_stride=None): """ diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index d54495047e2713..493352798eab8c 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -3,6 +3,17 @@ import math import operator import sys +from typing import ( + Any, + Callable, + Iterable, + List, + Optional, + SupportsFloat, + Tuple, + TypeVar, + Union, +) import sympy from sympy import S @@ -19,6 +30,8 @@ from .numbers import int_oo +_T = TypeVar("_T", bound=SupportsFloat) + # Portions of this file are adapted from the Sympy codebase, which was # licensed as follows: # @@ -53,13 +66,20 @@ __all__ = [ "FloorDiv", "ModularIndexing", + "Where", + "PythonMod", + "Mod", "CleanDiv", + "CeilToInt", + "FloorToInt", "CeilDiv", "IntTrueDiv", "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", + "TruncToFloat", + "TruncToInt", "RoundToInt", "RoundDecimal", "ToFloat", @@ -69,9 +89,9 @@ ] -def _keep_float(f): +def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]: @functools.wraps(f) - def inner(*args): + def inner(*args: Any) -> Union[_T, sympy.Float]: r = f(*args) if any(isinstance(a, sympy.Float) for a in args) and not isinstance( r, sympy.Float @@ -82,12 +102,56 @@ def inner(*args): return inner -def fuzzy_eq(x, y): +def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]: if None in (x, y): return None return x == y +def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic: + """ + Fast path for sympy.gcd, using a simple factoring strategy. + + We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, + where n is the greatest common integer factor and e is the largest + syntactic common factor (i.e., common sub-expression) in p and q. + Then the gcd returned is n*e, cancelling which we would be left with + p1 + p2 and q0. + + Note that further factoring of p1 + p2 and q0 might be possible with + sympy.factor (which uses domain-specific theories). E.g., we are unable + to find that x*y + x + y + 1 is divisible by x + 1. More generally, + when q is of the form q1 + q2 (instead of being already factored) it + might be necessary to fall back on sympy.gcd. + """ + + def integer_coefficient(x: sympy.Basic) -> int: + integer_coefficients: List[int] = [ + abs(int(arg)) + for arg in sympy.Mul.make_args(x) + if isinstance(arg, (int, sympy.Integer)) + ] + return math.prod(integer_coefficients) + + def integer_factor(expr: sympy.Basic) -> int: + integer_factors: Iterable[int] = map( + integer_coefficient, sympy.Add.make_args(expr) + ) + return functools.reduce(math.gcd, integer_factors) + + gcd: int = math.gcd(integer_factor(p), integer_factor(q)) + p, q = p / gcd, q / gcd + + base_splits: List[Tuple[sympy.Basic, ...]] = list( + map(sympy.Mul.make_args, sympy.Add.make_args(p)) + ) + divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q) + for x in divisor_split: + if all(x in base_split for base_split in base_splits): + gcd = gcd * x + return gcd + + # It would be nice to have assertions on whether or not inputs is_integer # However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy # sometimes inconsistently reports floats an integers. @@ -115,20 +179,19 @@ class FloorDiv(sympy.Function): NB: This is Python-style floor division, round to -Inf """ - nargs = (2,) - precedence = 50 # precedence of mul # noqa: F811 - - is_integer = True + nargs: Tuple[int, ...] = (2,) + precedence: int = 50 # precedence of mul # noqa: F811 + is_integer: bool = True @property - def base(self): + def base(self) -> sympy.Basic: return self.args[0] @property - def divisor(self): + def divisor(self) -> sympy.Basic: return self.args[1] - def _sympystr(self, printer): + def _sympystr(self, printer: sympy.printing.printer.Printer) -> str: base = printer.parenthesize(self.base, self.precedence) divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" @@ -136,7 +199,7 @@ def _sympystr(self, printer): # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod - def eval(cls, base, divisor): + def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]: # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full # Assert triggered by inequality solver # assert base.is_integer, base @@ -186,8 +249,7 @@ def eval(cls, base, divisor): # Expands (x + y) // b into x // b + y // b. # This only works if floor is an identity, i.e. x / b is an integer. - base_args = sympy.Add.make_args(base) - for term in base_args: + for term in sympy.Add.make_args(base): quotient = term / divisor if quotient.is_integer and isinstance(divisor, sympy.Integer): # NB: this is correct even if the divisor is not an integer, but it @@ -196,28 +258,31 @@ def eval(cls, base, divisor): return FloorDiv(base - term, divisor) + quotient try: - # sympy.gcd tends to blow up on large sums, so use it on each summand instead - gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args) - if not equal_valued(gcd, 1) and all( - equal_valued(gcd, gcd_) for gcd_ in gcds_ - ): + gcd = simple_floordiv_gcd(base, divisor) + if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add): + gcd = sympy.gcd(base, divisor) + if not equal_valued(gcd, 1): return FloorDiv( sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) ) except sympy.PolynomialError: pass # https://github.com/pytorch/pytorch/issues/108276 + return None + class ModularIndexing(sympy.Function): """ ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus """ - nargs = (3,) - is_integer = True + nargs: Tuple[int, ...] = (3,) + is_integer: bool = True @classmethod - def eval(cls, base, divisor, modulus): + def eval( + cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic + ) -> Optional[sympy.Basic]: if base == 0 or modulus == 1: return sympy.Integer(0) @@ -241,8 +306,8 @@ def eval(cls, base, divisor, modulus): pass # https://github.com/pytorch/pytorch/issues/108276 if isinstance(base, sympy.Add): - new_terms = [] - all_positive = True + new_terms: List[sympy.Basic] = [] + all_positive: bool = True for term in base.args: if sympy.gcd(term, modulus * divisor) != modulus * divisor: if (isinstance(term, sympy.Integer) and term < 0) or ( @@ -265,11 +330,13 @@ def eval(cls, base, divisor, modulus): if isinstance(base, FloorDiv): return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) - def _eval_is_nonnegative(self): + return None + + def _eval_is_nonnegative(self) -> Optional[bool]: p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] - def _eval_is_positive(self): + def _eval_is_positive(self) -> Optional[bool]: p, q = self.args[:2] return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined] @@ -279,37 +346,40 @@ class Where(sympy.Function): Good ol' ternary operator """ - nargs = (3,) + nargs: Tuple[int, ...] = (3,) - def _eval_is_integer(self): + def _eval_is_integer(self) -> Optional[bool]: return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self) -> Optional[bool]: return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self): + def _eval_is_positive(self) -> Optional[bool]: return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod - def eval(cls, c, p, q): + def eval( + cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic + ) -> Optional[sympy.Basic]: if c == sympy.true: return p elif c == sympy.false: return q + return None # Python-style modulus: take sign from RHS class PythonMod(sympy.Function): - nargs = (2,) + nargs: Tuple[int, ...] = (2,) - is_integer = True + is_integer: bool = True @classmethod - def eval(cls, p, q): + def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint # Triggered by sympy.solvers.inequalities.reduce_inequalities # assert p.is_integer, p @@ -351,11 +421,13 @@ def eval(cls, p, q): if sympy.Mod(p, q) == 0: return S.Zero + return None + # NB: args[1] for PythonMod - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self) -> Optional[bool]: return True if self.args[1].is_positive else None # type: ignore[attr-defined] - def _eval_is_nonpositive(self): + def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] @@ -778,13 +850,13 @@ class Max(MinMaxBase, Application): # type: ignore[misc] zero = S.Infinity identity = S.NegativeInfinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_and(a.is_negative for a in self.args) @@ -796,13 +868,13 @@ class Min(MinMaxBase, Application): # type: ignore[misc] zero = S.NegativeInfinity identity = S.Infinity - def _eval_is_positive(self): + def _eval_is_positive(self): # type:ignore[override] return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined] - def _eval_is_nonnegative(self): + def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] - def _eval_is_negative(self): + def _eval_is_negative(self): # type:ignore[override] return fuzzy_or(a.is_negative for a in self.args) diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 11fcded3452408..f845484db489e1 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import math import operator +from typing import Union import sympy @@ -281,3 +282,202 @@ def round_to_int(a, dtype): @staticmethod def round_decimal(a, b): return round(a, ndigits=b) + + +def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return torch.ops.aten._to_copy(x, dtype=dtype) + + +# Suppose we have some int/float arguments. This diagram commutes: +# +# int/float -- PythonReferenceAnalysis.op --> int/float +# | | +# | | +# torch.tensor(..., dtype=torch.int64/torch.float64) +# | | +# V V +# Tensor -- TensorReferenceAnalysis.op --> Tensor +# +# NB: int before and after must be representable in int64 (we will +# insert guards accordingly.) +# +# This is guaranteed to be FX traceable with OpOverloads only. +class TensorReferenceAnalysis: + # NB: This is actually dead, because with Proxy tracing the factory + # function isn't traced correctly. Here for completeness. + @staticmethod + def constant(c, dtype): + d: Union[int, float, bool] + if dtype is torch.int64: + d = int(c) + elif dtype is torch.double: + d = float(c) + elif dtype is torch.bool: + d = bool(c) + else: + raise AssertionError(f"unrecognized dtype {dtype}") + return torch.ops.aten.scalar_tensor.default(d, dtype=dtype) + + @staticmethod + def or_(a, b): + return torch.ops.aten.logical_or.default(a, b) + + @staticmethod + def and_(a, b): + return torch.ops.aten.logical_and.default(a, b) + + @staticmethod + def eq(a, b): + return torch.ops.aten.eq.Tensor(a, b) + + @classmethod + def ne(cls, a, b): + return torch.ops.aten.ne.Tensor(a, b) + + @staticmethod + def lt(a, b): + return torch.ops.aten.lt.Tensor(a, b) + + @staticmethod + def gt(a, b): + return torch.ops.aten.gt.Tensor(a, b) + + @staticmethod + def le(a, b): + return torch.ops.aten.le.Tensor(a, b) + + @staticmethod + def ge(a, b): + return torch.ops.aten.ge.Tensor(a, b) + + @staticmethod + def not_(a): + return torch.ops.aten.logical_not.default(a) + + @staticmethod + def reciprocal(x): + return torch.ops.aten.reciprocal.default(x) + + @staticmethod + def square(x): + # TODO: maybe composite implicit autograd doesn't work here? + return torch.ops.aten.square.default(x) + + @staticmethod + def trunc_to_int(x, dtype): + return _to_dtype(torch.ops.aten.trunc.default(x), dtype) + + @staticmethod + def ceil_to_int(x, dtype): + return _to_dtype(torch.ops.aten.ceil.default(x), dtype) + + @staticmethod + def floor_to_int(x, dtype): + return _to_dtype(torch.ops.aten.floor.default(x), dtype) + + @staticmethod + def floor(x): + return torch.ops.aten.floor.default(x) + + @staticmethod + def ceil(x): + return torch.ops.aten.ceil.default(x) + + @staticmethod + def to_dtype(x, dtype): + return _to_dtype(x, dtype) + + @staticmethod + def mod(x, y): + # TODO: https://github.com/pytorch/pytorch/pull/133654 + raise NotImplementedError( + "no C-style modulus operation available from frontend atm" + ) + + @staticmethod + def abs(x): + return torch.ops.aten.abs.default(x) + + @staticmethod + def neg(x): + return torch.ops.aten.neg.default(x) + + @staticmethod + def truediv(a, b): + return torch.ops.aten.true_divide.Tensor(a, b) + + @staticmethod + def int_truediv(a, b): + raise NotImplementedError( + "Python int truediv difficult to implement in PyTorch atm" + ) + + # TODO: This is wrong, CPython has a custom implementation of true + # division that results in higher precision when the floats are + # sufficiently large. Short term fix: add a guard here + return torch.ops.aten.true_divide.default( + _to_dtype(a, torch.float64), _to_dtype(b, torch.float64) + ) + + @staticmethod + def floordiv(a, b): + return torch.ops.aten.floor_divide(a, b) + + @staticmethod + def truncdiv(a, b): + raise NotImplementedError( + "no C-style truncdiv operation available from frontend atm" + ) + + @staticmethod + def add(a, b): + return torch.ops.aten.add.Tensor(a, b) + + @staticmethod + def mul(a, b): + return torch.ops.aten.mul.Tensor(a, b) + + @staticmethod + def sub(a, b): + return torch.ops.aten.sub.Tensor(a, b) + + @staticmethod + def exp(x): + return torch.ops.aten.exp.default(x) + + @staticmethod + def log(x): + return torch.ops.aten.log.default(x) + + @staticmethod + def sqrt(x): + return torch.ops.aten.sqrt.default(x) + + @staticmethod + def pow(a, b): + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def pow_by_natural(a, b): + # NB: pow handles int x int fine + return torch.ops.aten.pow.Tensor_Tensor(a, b) + + @staticmethod + def minimum(a, b): + return torch.ops.aten.minimum.default(a, b) + + @staticmethod + def maximum(a, b): + return torch.ops.aten.maximum.default(a, b) + + @staticmethod + def round_to_int(a, dtype): + return torch.ops.aten.round.default(a) + + @staticmethod + def round_decimal(a, b): + raise NotImplementedError( + "round decimal doesn't support Tensor second argument atm" + ) + + # return torch.ops.aten.round.decimals(a, b) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 57d509323e01f4..2dbad5241d039d 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -87,7 +87,9 @@ def simple_sympify(e): def sympy_generic_le(lower, upper): if isinstance(lower, sympy.Expr): assert isinstance(upper, sympy.Expr) - return lower <= upper + # instead of lower <= upper, we do upper >= lower since upper is mostly int_oo + # and we have better code paths there. + return upper >= lower else: # only negative condition is True > False assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), ( @@ -552,9 +554,9 @@ def mul(cls, a, b): def safe_mul(a, b): # Make unknown() * wrap(0.0) == wrap(0.0) - if a == 0.0: + if a == 0.0 or a == 0: return a - elif b == 0.0: + elif b == 0.0 or b == 0: return b else: return a * b @@ -590,7 +592,7 @@ def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b: - return ValueRanges.unknown() + return ValueRanges.unknown_int() products = [] for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]): r = FloorDiv(x, y) diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 22d616b83faa47..94a8744e5c47a2 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -433,7 +433,7 @@ def checkpoint( use_reentrant(bool): specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed - explicitly. In version 2.4 we will raise an exception if + explicitly. In version 2.5 we will raise an exception if ``use_reentrant`` is not passed. If ``use_reentrant=False``, ``checkpoint`` will use an implementation that does not require reentrant autograd. This allows ``checkpoint`` to support additional @@ -464,7 +464,7 @@ def checkpoint( if use_reentrant is None: warnings.warn( "torch.utils.checkpoint: the use_reentrant parameter should be " - "passed explicitly. In version 2.4 we will raise an exception " + "passed explicitly. In version 2.5 we will raise an exception " "if use_reentrant is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " @@ -533,7 +533,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar use_reentrant(bool): specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed - explicitly. In version 2.4 we will raise an exception if + explicitly. In version 2.5 we will raise an exception if ``use_reentrant`` is not passed. If ``use_reentrant=False``, ``checkpoint`` will use an implementation that does not require reentrant autograd. This allows ``checkpoint`` to support additional @@ -553,7 +553,7 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar warnings.warn( "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " "parameter should be passed explicitly. " - "In version 2.4 we will raise an exception if use_reentrant " + "In version 2.5 we will raise an exception if use_reentrant " "is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aaa45ea4c909a3..4c7366193cc5c8 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -140,6 +140,21 @@ def _find_rocm_home() -> Optional[str]: file=sys.stderr) return rocm_home +def _find_sycl_home() -> Optional[str]: + """Find the OneAPI install path.""" + # Guess #1 + sycl_home = os.environ.get('ONEAPI_ROOT') + if sycl_home is None: + # Guess #2 + icpx_path = shutil.which('icpx') + if icpx_path is not None: + sycl_home = os.path.dirname(os.path.dirname( + os.path.realpath(icpx_path))) + + if sycl_home and not torch.xpu.is_available(): + print(f"No XPU runtime is found, using ONEAPI_ROOT='{sycl_home}'", + file=sys.stderr) + return sycl_home def _join_rocm_home(*paths) -> str: """ @@ -156,6 +171,20 @@ def _join_rocm_home(*paths) -> str: 'ROCm and Windows is not supported.') return os.path.join(ROCM_HOME, *paths) +def _join_sycl_home(*paths) -> str: + """ + Join paths with SYCL_HOME, or raises an error if it SYCL_HOME is not set. + + This is basically a lazy way of raising an error for missing $SYCL_HOME + only once we need to get any SYCL-specific path. + """ + if SYCL_HOME is None: + raise OSError('SYCL_HOME environment variable is not set. ' + 'Please set it to your OneAPI install root.') + + return os.path.join(SYCL_HOME, *paths) + + ABI_INCOMPATIBILITY_WARNING = ''' @@ -207,6 +236,8 @@ def _join_rocm_home(*paths) -> str: CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH') +SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None + # PyTorch releases have the version pattern major.minor.patch, whereas when # PyTorch is built from source, we append the git commit hash, which gives # it the below pattern. @@ -1075,7 +1106,7 @@ def CUDAExtension(name, sources, *args, **kwargs): ... 'nvcc': ['-O2', '-rdc=true']}) """ library_dirs = kwargs.get('library_dirs', []) - library_dirs += library_paths(cuda=True) + library_dirs += library_paths(device_type="cuda") kwargs['library_dirs'] = library_dirs libraries = kwargs.get('libraries', []) @@ -1119,7 +1150,7 @@ def CUDAExtension(name, sources, *args, **kwargs): sources = list(hipified_sources) - include_dirs += include_paths(cuda=True) + include_dirs += include_paths(device_type="cuda") kwargs['include_dirs'] = include_dirs kwargs['language'] = 'c++' @@ -1144,9 +1175,9 @@ def CUDAExtension(name, sources, *args, **kwargs): return setuptools.Extension(name, sources, *args, **kwargs) -def include_paths(cuda: bool = False) -> List[str]: +def include_paths(device_type: str = "cpu") -> List[str]: """ - Get the include paths required to build a C++ or CUDA extension. + Get the include paths required to build a C++ or CUDA or SYCL extension. Args: cuda: If `True`, includes CUDA-specific include paths. @@ -1164,10 +1195,10 @@ def include_paths(cuda: bool = False) -> List[str]: os.path.join(lib_include, 'TH'), os.path.join(lib_include, 'THC') ] - if cuda and IS_HIP_EXTENSION: + if device_type == "cuda" and IS_HIP_EXTENSION: paths.append(os.path.join(lib_include, 'THH')) paths.append(_join_rocm_home('include')) - elif cuda: + elif device_type == "cuda": cuda_home_include = _join_cuda_home('include') # if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home. # but gcc doesn't like having /usr/include passed explicitly @@ -1180,10 +1211,12 @@ def include_paths(cuda: bool = False) -> List[str]: paths.append(cuda_inc_path) if CUDNN_HOME is not None: paths.append(os.path.join(CUDNN_HOME, 'include')) + elif device_type == "xpu": + paths.append(_join_sycl_home('include')) return paths -def library_paths(cuda: bool = False) -> List[str]: +def library_paths(device_type: str = "cpu") -> List[str]: """ Get the library paths required to build a C++ or CUDA extension. @@ -1196,12 +1229,12 @@ def library_paths(cuda: bool = False) -> List[str]: # We need to link against libtorch.so paths = [TORCH_LIB_PATH] - if cuda and IS_HIP_EXTENSION: + if device_type == "cuda" and IS_HIP_EXTENSION: lib_dir = 'lib' paths.append(_join_rocm_home(lib_dir)) if HIP_HOME is not None: paths.append(os.path.join(HIP_HOME, 'lib')) - elif cuda: + elif device_type == "cuda": if IS_WINDOWS: lib_dir = os.path.join('lib', 'x64') else: @@ -1216,6 +1249,17 @@ def library_paths(cuda: bool = False) -> List[str]: paths.append(_join_cuda_home(lib_dir)) if CUDNN_HOME is not None: paths.append(os.path.join(CUDNN_HOME, lib_dir)) + elif device_type == "xpu": + if IS_WINDOWS: + lib_dir = os.path.join('lib', 'x64') + else: + lib_dir = 'lib64' + if (not os.path.exists(_join_sycl_home(lib_dir)) and + os.path.exists(_join_sycl_home('lib'))): + lib_dir = 'lib' + + paths.append(_join_sycl_home(lib_dir)) + return paths @@ -2165,7 +2209,11 @@ def _write_ninja_file_to_build_library(path, user_includes = [os.path.abspath(file) for file in extra_include_paths] # include_paths() gives us the location of torch/extension.h - system_includes = include_paths(with_cuda) + # TODO generalize with_cuda as specific device type. + if with_cuda: + system_includes = include_paths("cuda") + else: + system_includes = include_paths("cpu") # sysconfig.get_path('include') gives us the location of Python.h # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 91958127b03be8..63a6b817c42462 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import logging import weakref from typing import Set @@ -11,6 +12,9 @@ from torch.utils._pytree import tree_flatten +logger = logging.getLogger(__name__) + + __all__ = ["ModuleTracker"] @@ -93,9 +97,10 @@ def fn(*args): if is_bw: self._maybe_set_engine_callback() if name in self.parents: - print( - "The module hierarchy tracking seems to be messed up." - "Please file a bug to PyTorch." + logger.info( + "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s", + name, + "backward" if is_bw else "forward", ) self.parents.add(name) @@ -105,11 +110,11 @@ def _get_pop_fn(self, name, is_bw): def fn(*args): if name in self.parents: self.parents.remove(name) - elif not is_bw: - # Due to some input/output not requiring gradients, we cannot enforce - # proper nesting in backward - raise RuntimeError( - "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + else: + logger.info( + "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s", + name, + "backward" if is_bw else "forward", ) return fn diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index 30984cfadf17fc..8acaf1696cb1f2 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import numpy as np +import numpy.typing as npt # Functions for converting @@ -21,7 +22,7 @@ def figure_to_image(figures, close=True): def render_to_rgb(figure): canvas = plt_backend_agg.FigureCanvasAgg(figure) canvas.draw() - data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) + data: npt.NDArray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) w, h = figure.canvas.get_width_height() image_hwc = data.reshape([h, w, 4])[:, :, 0:3] image_chw = np.moveaxis(image_hwc, source=2, destination=0) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index cc272a7f26375a..f729ff06489ffe 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -263,7 +263,7 @@ def pop(self, key, *args): def setdefault(self, key, default=None): return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED - def update(self, dict=None, **kwargs): + def update(self, dict=None, **kwargs): # type: ignore[override] d = self.data if dict is not None: if not hasattr(dict, "items"): diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index b6eafc38ff1da2..a51924f89d21fe 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -388,19 +388,6 @@ def synchronize(device: _device_t = None) -> None: return torch._C._xpu_synchronize(device) -def empty_cache() -> None: - r"""Release all unoccupied cached memory currently held by the caching - allocator so that those can be used in other XPU application. - - .. note:: - :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU - memory available for PyTorch. However, it may help reduce fragmentation - of XPU memory in certain cases. - """ - if is_initialized(): - torch._C._xpu_emptyCache() - - def _get_generator(device: torch.device) -> torch._C.Generator: r"""Return the XPU Generator object for the given device. @@ -448,7 +435,29 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: return default_generator.get_offset() -from .random import * # noqa: F403 +# import here to avoid circular import +from .memory import ( + empty_cache, + max_memory_allocated, + max_memory_reserved, + memory_allocated, + memory_reserved, + memory_stats, + memory_stats_as_nested_dict, + reset_accumulated_memory_stats, + reset_peak_memory_stats, +) +from .random import ( + get_rng_state, + get_rng_state_all, + initial_seed, + manual_seed, + manual_seed_all, + seed, + seed_all, + set_rng_state, + set_rng_state_all, +) __all__ = [ @@ -475,6 +484,14 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: "is_initialized", "manual_seed", "manual_seed_all", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", "seed", "seed_all", "set_device", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py new file mode 100644 index 00000000000000..32e7f6c2ce315f --- /dev/null +++ b/torch/xpu/memory.py @@ -0,0 +1,191 @@ +import collections +from typing import Any, Dict, Union + +import torch +from torch.types import Device + +from . import _get_device_index, is_initialized + + +_device_t = Union[Device, str, int, None] + + +def empty_cache() -> None: + r"""Release all unoccupied cached memory currently held by the caching + allocator so that those can be used in other XPU application. + + .. note:: + :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU + memory available for PyTorch. However, it may help reduce fragmentation + of XPU memory in certain cases. + """ + if is_initialized(): + torch._C._xpu_emptyCache() + + +def reset_peak_memory_stats(device: _device_t = None) -> None: + r"""Reset the "peak" stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Peak stats correspond to the + `"peak"` key in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetPeakMemoryStats(device) + + +def reset_accumulated_memory_stats(device: _device_t = None) -> None: + r"""Reset the "accumulated" (historical) stats tracked by the XPU memory allocator. + + See :func:`~torch.xpu.memory_stats` for details. Accumulated stats correspond to + the `"allocated"` and `"freed"` keys in each individual stat dict. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + device = _get_device_index(device, optional=True) + return torch._C._xpu_resetAccumulatedMemoryStats(device) + + +def memory_stats_as_nested_dict(device: _device_t = None) -> Dict[str, Any]: + r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary.""" + if not is_initialized(): + return {} + device = _get_device_index(device, optional=True) + return torch._C._xpu_memoryStats(device) + + +def memory_stats(device: _device_t = None) -> Dict[str, Any]: + r"""Return a dictionary of XPU memory allocator statistics for a given device. + + The return value of this function is a dictionary of statistics, each of + which is a non-negative integer. + + Core statistics: + + - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of allocated memory. + - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of reserved memory. + - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + amount of active memory. + - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``: + memory requested by client code, compare this with allocated_bytes to check if + allocation rounding adds too much overhead. + + For these core statistics, values are broken down as follows. + + Pool type: + + - ``all``: combined statistics across all memory pools. + - ``large_pool``: statistics for the large allocation pool (for size >= 1MB allocations). + - ``small_pool``: statistics for the small allocation pool (for size < 1MB allocations). + + Metric type: + + - ``current``: current value of this metric. + - ``peak``: maximum value of this metric. + - ``allocated``: historical total increase in this metric. + - ``freed``: historical total decrease in this metric. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistics for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + result = [] + + def _recurse_add_to_result(prefix: str, obj: Any) -> None: + if isinstance(obj, dict): + if len(prefix) > 0: + prefix += "." + for k, v in obj.items(): + _recurse_add_to_result(prefix + k, v) + else: + result.append((prefix, obj)) + + stats = memory_stats_as_nested_dict(device=device) + _recurse_add_to_result("", stats) + result.sort() + + return collections.OrderedDict(result) + + +def memory_allocated(device: _device_t = None) -> int: + r"""Return the current GPU memory occupied by tensors in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + + .. note:: + This is likely less than the amount shown in `xpu-smi` since some + unused memory can be held by the caching allocator and some context + needs to be created on GPU. + """ + return memory_stats(device=device).get("allocated_bytes.all.current", 0) + + +def max_memory_allocated(device: _device_t = None) -> int: + r"""Return the maximum GPU memory occupied by tensors in bytes for a given device. + + By default, this returns the peak allocated memory since the beginning of + this program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to + reset the starting point in tracking this metric. For example, these two + functions can measure the peak allocated memory usage of each iteration in a + training loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + + +def memory_reserved(device: _device_t = None) -> int: + r"""Return the current GPU memory managed by the caching allocator in bytes for a given device. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.current", 0) + + +def max_memory_reserved(device: _device_t = None) -> int: + r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device. + + By default, this returns the peak cached memory since the beginning of this + program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to reset + the starting point in tracking this metric. For example, these two functions + can measure the peak cached memory amount of each iteration in a training + loop. + + Args: + device (torch.device or int or str, optional): selected device. Returns + statistic for the current device, given by :func:`~torch.xpu.current_device`, + if :attr:`device` is ``None`` (default). + """ + return memory_stats(device=device).get("reserved_bytes.all.peak", 0) + + +__all__ = [ + "empty_cache", + "max_memory_allocated", + "max_memory_reserved", + "memory_allocated", + "memory_reserved", + "memory_stats", + "memory_stats_as_nested_dict", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", +] diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 19e5a11b8cdf2a..c657570ee3e249 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -94,6 +94,7 @@ def valuetype_type( t: Type, *, binds: ArgName, + mutable: bool = True, remove_non_owning_ref_types: bool = False, symint: bool = False, ) -> NamedCType | None: @@ -113,7 +114,7 @@ def valuetype_type( # All other BaseType currently map directly to BaseCppTypes. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name])) elif isinstance(t, OptionalType): - elem = valuetype_type(t.elem, binds=binds, symint=symint) + elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint) if elem is None: return None return NamedCType(binds, OptionalCType(elem.type)) @@ -143,6 +144,7 @@ def argumenttype_type( r = valuetype_type( t, binds=binds, + mutable=mutable, symint=symint, remove_non_owning_ref_types=remove_non_owning_ref_types, ) @@ -231,7 +233,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: # placeholder is ignored # NB: symint is ALWAYS respected for return types. So symint argument # here is IGNORED - r = valuetype_type(t, binds="__placeholder__", symint=True) + r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True) if r is not None: return r.type diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index a93d666114de64..93a72eb2b4a5c1 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -48,7 +48,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType: # CompositeExplicitAutograd and the meta function (which could # hypothetically be SymInt), but for simplicity we plan for these to just # be handled in Python - r = cpp.valuetype_type(t, symint=False, binds=binds) + r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable) if r is not None: return r diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index 0bdf28acfa6418..76cebcd0f0f1dc 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -5,7 +5,6 @@ from torchgen import local from torchgen.api.types import ( ArgName, - ArrayCType, BaseCType, Binding, ConstRefCType, @@ -88,7 +87,7 @@ def valuetype_type( if str(t.elem) == "bool": assert t.size is not None return NamedCType( - binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size) + binds, ArrayRefCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool])) ) else: return None diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py index f206980af44fd7..6845e72a22a5d8 100644 --- a/torchgen/executorch/api/unboxing.py +++ b/torchgen/executorch/api/unboxing.py @@ -173,7 +173,16 @@ def _gen_code_list_type( # handle list type with size, e.g., bool[4] code.extend( f""" - auto {out_name} = {arg_name}.toBoolList(); +#ifdef USE_ATEN_LIB +std::array {out_name}; +auto {in_name} = {arg_name}.toBoolList(); +size_t _i = 0; +for (auto {elem_name}: {in_name}) {{ + {out_name}[_i++] = {elem_name}; +}} +#else +auto {out_name} = {arg_name}.toBoolList(); +#endif """.split( "\n" ) diff --git a/torchgen/gen.py b/torchgen/gen.py index 27ff0c48caa3cc..e5870a24fc6684 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -59,6 +59,7 @@ is_cuda_dispatch_key, is_generic_dispatch_key, is_ufunc_dispatch_key, + is_xpu_dispatch_key, Location, NativeFunction, NativeFunctionsGroup, @@ -184,7 +185,7 @@ def parse_native_yaml_struct( use_out_as_primary=True, external=False, # Only cuda-like devices in tree require device guards - device_guard=is_cuda_dispatch_key(k), + device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k), index=v, ) return ParsedYaml(rs, indices) diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 0e8d79cf679c2d..353302c7cd4a16 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -337,12 +337,11 @@ def compute_native_function_declaration( metadata_list = kernel_index.get_kernels(g).values() if metadata_list is None: return [] - prefix = "TORCH_API" # for kernels in lean mode, we declare two versions, one with context and one without. # In the end we will cleanup the unused one. def gen_decl(metadata: BackendMetadata, include_context: bool) -> str: - return f"{prefix} {sig.decl(name=metadata.kernel, include_context=include_context)};" + return f"{sig.decl(name=metadata.kernel, include_context=include_context)};" return [ gen_decl(metadata, include_context) @@ -499,11 +498,11 @@ def gen_headers( headers = { "headers": [ "#include // at::Tensor etc.", - "#include // TORCH_API", "#include ", ], } if use_aten_lib: + headers["headers"].append("#include // TORCH_API") cpu_fm.write( "NativeFunctions.h", lambda: dict( diff --git a/torchgen/model.py b/torchgen/model.py index 7459587e31d64a..956949343101ad 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -323,6 +323,18 @@ def is_cuda_dispatch_key(dk: DispatchKey) -> bool: } +# XPU specific dispatcy keys +def is_xpu_dispatch_key(dk: DispatchKey) -> bool: + return dk in { + DispatchKey.XPU, + DispatchKey.QuantizedXPU, + DispatchKey.SparseXPU, + DispatchKey.SparseCsrXPU, + DispatchKey.NestedTensorXPU, + DispatchKey.AutogradXPU, + } + + # Structured kernel generation is only supported for certain key types; # otherwise use old-style def is_structured_dispatch_key(dk: DispatchKey) -> bool: diff --git a/version.txt b/version.txt index b8feefb940f931..3d87ca93f8a9bf 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.5.0a0 +2.6.0a0