diff --git a/.github/actions/inductor-xpu-e2e-test/action.yml b/.github/actions/inductor-xpu-e2e-test/action.yml
index 6e1dd4268..f4840f92b 100644
--- a/.github/actions/inductor-xpu-e2e-test/action.yml
+++ b/.github/actions/inductor-xpu-e2e-test/action.yml
@@ -41,7 +41,7 @@ runs:
     shell: bash
     run: |
       source activate e2e_ci
-      source /opt/intel/oneapi/compiler/latest/env/vars.sh
+      source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
       if [[ ${{ inputs.suite }} == *"torchbench"* ]]; then
         cd ../ && rm -rf audio && git clone --single-branch -b main https://github.com/pytorch/audio.git
         cd audio && git checkout $TORCHAUDIO_COMMIT_ID
@@ -80,7 +80,7 @@ runs:
       source activate e2e_ci
       cp .github/scripts/inductor_xpu_test.sh ../pytorch
       cd ../pytorch
-      source /opt/intel/oneapi/compiler/latest/env/vars.sh
+      source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
       rm -f ${{ github.workspace }}/summary_accuracy.log
       # check param
       function contains() {
@@ -198,7 +198,7 @@ runs:
      source activate e2e_ci
      cp .github/scripts/inductor_perf_summary.py ../pytorch
      cd ../pytorch
-      source /opt/intel/oneapi/compiler/latest/env/vars.sh
+      source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
      pip install styleFrame scipy pandas
      set -xe
      for suite in $(echo ${{ inputs.suite }} |sed 's/,/ /g')
diff --git a/.github/workflows/_linux_ut.yml b/.github/workflows/_linux_ut.yml
index 7cf2746c3..4abc9a95b 100644
--- a/.github/workflows/_linux_ut.yml
+++ b/.github/workflows/_linux_ut.yml
@@ -1,32 +1,37 @@
-name: inductor-xpu-ut-test
+name: Linux UT Test
 
 on:
   workflow_call:
     inputs:
-      torch_xpu_ops_update:
+      pytorch:
        required: false
        type: string
-        default: 'true'
-        description: True means update xpu_ops when building pytorch, otherwise means not
-      ut_suite:
+        default: 'main'
+        description: PyTorch branch/commit
+      keep_torch_xpu_ops:
+        required: false
+        type: string
+        default: 'false'
+        description: Keep torch-xpu-ops pin. `true` means use pinned commit
+      ut:
        required: true
        type: string
-        default: 'op_example,op_extended,op_ut,torch_xpu'
-        description: op_example,op_extended,op_ut,torch_xpu. Delimiter is comma
-      pytorch_branch:
+        default: ''
+        description: UT scope. `op_example,op_extended,op_ut,torch_xpu`. Delimiter is comma
+      python:
        required: false
        type: string
-        default: 'main'
-        description: Set pytorch branch
+        default: '3.10'
+        description: Python version
      runner:
        required: true
        type: string
        default: 'linux.idc.xpu'
-        description: Set runner
+        description: Runner label
 
 jobs:
-  Inductor-XPU-UT-Tests:
+  Torch-XPU-UT-Tests:
    runs-on: ${{ inputs.runner }}
    timeout-minutes: 900
    steps:
@@ -36,60 +41,60 @@ jobs:
      run: |
        pwd
        cd ../ && rm -rf pytorch
-        git clone -b ${{ inputs.pytorch_branch }} https://github.com/pytorch/pytorch
-        cd pytorch && git log -n 1 && git submodule sync && git submodule update --init --recursive
-        if [ -z ${{ inputs.torch_xpu_ops_update }} ]; then
-          rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
+        git clone https://github.com/pytorch/pytorch pytorch
+        cd pytorch && git checkout ${{ inputs.pytorch }}
+        # apply PRs for stock pytorch
+        pip install requests
+        python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
+        git status && git show -s
+        git submodule sync && git submodule update --init --recursive
+        if [[ ${{ inputs.keep_torch_xpu_ops }} == 'true' ]]; then
+          echo "Don't replace torch-xpu-ops!"
        else
-          if [[ ${{ inputs.torch_xpu_ops_update }} == 'true' ]]; then
-            rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
-          else
-            echo "Not update torch-xpu-ops"
-          fi
+          rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
+          # Workaround for torch-xpu-ops ci test
+          sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
        fi
-        # Workaround for torch-xpu-ops ci test
-        sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
    - name: Build Pytorch XPU
      run: |
        which conda && conda clean -ay
        conda remove --all -y -n xpu_op_${ZE_AFFINITY_MASK} || \
          rm -rf $(dirname ${CONDA_EXE})/../envs/xpu_op_${ZE_AFFINITY_MASK}
-        conda create -n xpu_op_${ZE_AFFINITY_MASK} python=3.10 cmake ninja -y
+        conda create -n xpu_op_${ZE_AFFINITY_MASK} python=${{ inputs.python }} cmake ninja -y
        source activate xpu_op_${ZE_AFFINITY_MASK}
        conda install -c intel mkl-static mkl-include -y
        cd ../pytorch
        pip install -r requirements.txt
        export USE_XPU=1
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
        python setup.py bdist_wheel
        pip install --force-reinstall dist/*.whl
        git clone https://github.com/pytorch/vision && cd vision && python setup.py install && cd ..
        pip install -r .ci/docker/requirements-ci.txt
    - name: Run XPU OP Examples
-      if: contains(inputs.ut_suite, 'op_example')
+      if: contains(inputs.ut, 'op_example') || github.event_name == 'schedule'
      run: |
        cd ${{ github.workspace }}
-        mkdir -p ut_log
        xpu-smi discovery
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        source activate xpu_op_${ZE_AFFINITY_MASK}
        cd ${{ github.workspace }}
        cd examples
        pip install pytest
        timeout 8000 pytest -v
    - name: Run XPU OP Extended UT
-      if: contains(inputs.ut_suite, 'op_extended')
+      if: contains(inputs.ut, 'op_extended') || github.event_name == 'schedule'
      run: |
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        source activate xpu_op_${ZE_AFFINITY_MASK}
        export PYTORCH_TEST_WITH_SLOW=1
        cd ../pytorch/third_party/torch-xpu-ops/test/xpu/extended/
        timeout 10000 python run_test_with_skip.py
    - name: Run XPU OP UT
-      if: contains(inputs.ut_suite, 'op_ut')
+      if: contains(inputs.ut, 'op_ut') || github.event_name == 'schedule'
      run: |
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        source activate xpu_op_${ZE_AFFINITY_MASK}
        export PYTORCH_ENABLE_XPU_FALLBACK=1
        export PYTORCH_TEST_WITH_SLOW=1
@@ -101,9 +106,9 @@ jobs:
        # test_foreach, test_decomp
        timeout 10000 python run_test_with_only.py
    - name: Run Torch XPU UT
-      if: contains(inputs.ut_suite, 'torch_xpu')
+      if: contains(inputs.ut, 'torch_xpu') || github.event_name == 'schedule'
      run: |
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        source activate xpu_op_${ZE_AFFINITY_MASK}
        cd ../pytorch
        TEST_REPORTS_DIR=$(pwd)/test/test-reports
@@ -117,7 +122,21 @@ jobs:
          fi
        done
        # Run Pytorch XPU python UT
-        export PYTORCH_ENABLE_XPU_FALLBACK=1
-        sed -i 's/selected_tests = exclude_tests(XPU_BLOCKLIST.*/selected_tests = XPU_TEST/g' ./test/run_test.py
-        python test/run_test.py --xpu
+        export PYTORCH_TEST_WITH_SLOW=1
+        export PYTORCH_TESTING_DEVICE_ONLY_FOR="xpu"
+        test_cmd="python test/run_test.py --include "
+        # All Inductor UT under test/inductor
+        for test in $(ls test/inductor | grep test);
+        do
+          test_cmd="${test_cmd} inductor/$test";
+        done
+        # All xpu ut under test/xpu
+        for test in $(ls test/xpu | grep test);
+        do
+          test_cmd="${test_cmd} xpu/$test";
+        done
+        if [ -f "test/test_xpu.py" ]; then
+          test_cmd="${test_cmd} test_xpu.py"
+        fi
+        eval $test_cmd
diff --git a/.github/workflows/inductor_xpu_e2e_ci.yml b/.github/workflows/inductor_xpu_e2e_ci.yml
deleted file mode 100644
index c7d408b33..000000000
--- a/.github/workflows/inductor_xpu_e2e_ci.yml
+++ /dev/null
@@ -1,137 +0,0 @@
-name: E2E CI Tests
-
-on:
-  workflow_dispatch:
-  pull_request:
-    types:
-      - opened
-      - synchronize
-      - reopened
-      - converted_to_draft
-      - ready_for_review
-    branches:
-      - main
-      - release/*
-
-permissions: read-all
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
-  cancel-in-progress: true
-
-jobs:
-  Inductor-XPU-E2E-CI-Tests:
-    runs-on: pvc_e2e
-    # Don't run on forked repos and draft PRs
-    if: ${{ (github.repository_owner == 'intel') && (github.event.pull_request.draft == false) }}
-    timeout-minutes: 900
-    steps:
-      - name: Checkout torch-xpu-ops
-        uses: actions/checkout@v4
-      - name: Prepare Conda ENV
-        run: |
-          which conda && conda clean -ay
-          conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
-          conda create -n e2e_ci python=3.10 cmake ninja -y
-          source activate e2e_ci
-          conda install -c intel mkl-static mkl-include -y
-          pip install pandas scipy tqdm
-      - name: Prepare Stock Pytorch
-        run: |
-          pwd
-          cd ../ && rm -rf pytorch
-          source activate e2e_ci
-          git clone -b main https://github.com/pytorch/pytorch pytorch
-          cd pytorch
-          # apply PRs for stock pytorch
-          pip install requests
-          # https://github.com/mengfei25/pytorch/pull/18 internal use only for subset model list
-          python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py -e https://github.com/mengfei25/pytorch/pull/18
-          git status && git show -s
-          git submodule sync && git submodule update --init --recursive
-          rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
-          # Workaround for torch-xpu-ops ci test
-          sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
-      - name: Triton Installation
-        run: |
-          source activate e2e_ci
-          cd ../pytorch
-          TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton"
-          TRITON_PINNED_COMMIT=$(cat .ci/docker/ci_commit_pins/triton-xpu.txt)
-          echo ${TRITON_REPO}@${TRITON_PINNED_COMMIT}
-          pip install --force-reinstall "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
-      - name: Build Pytorch XPU
-        run: |
-          source activate e2e_ci
-          cd ../pytorch
-          pip install -r requirements.txt
-          export USE_XPU=1
-          source /opt/intel/oneapi/compiler/latest/env/vars.sh
-          export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
-          python setup.py bdist_wheel
-          pip install --force-reinstall dist/*.whl
-      - name: Identify pinned versions
-        run: |
-          cd ../pytorch
-          echo "TRITON_COMMIT_ID=$(<.ci/docker/ci_commit_pins/triton-xpu.txt)" >> "${GITHUB_ENV}"
-          echo "TORCHVISION_COMMIT_ID=$(<.github/ci_commit_pins/vision.txt)" >> "${GITHUB_ENV}"
-          echo "TORCHTEXT_COMMIT_ID=$(<.github/ci_commit_pins/text.txt)" >> "${GITHUB_ENV}"
-          echo "TORCHAUDIO_COMMIT_ID=$(<.github/ci_commit_pins/audio.txt)" >> "${GITHUB_ENV}"
-          echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" >> "${GITHUB_ENV}"
-          echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" >> "${GITHUB_ENV}"
-      - name: Show GITHUB_ENV
-        run: |
-          echo "$GITHUB_ENV"
-          rm -rf ../pytorch/inductor_log
-          rm -rf /tmp/torchinductor_*
-      - name: Huggingface BF16 Training Accuracy Test
-        uses: ./.github/actions/inductor-xpu-e2e-test
-        with:
-          suite: huggingface
-          dt: bfloat16
-          mode: training
-          scenario: accuracy
-          env_prepare: true
-          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-      - name: Huggingface FP16 Training Accuracy Test
-        uses: ./.github/actions/inductor-xpu-e2e-test
-        with:
-          suite: huggingface
-          dt: float16
-          mode: training
-          scenario: accuracy
-          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-      - name: Timm_models BF16 Training Accuracy Test
-        uses: ./.github/actions/inductor-xpu-e2e-test
-        with:
-          suite: timm_models
-          dt: bfloat16
-          mode: training
-          scenario: accuracy
-          env_prepare: true
-          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-      - name: Torchbench BF16 Training Accuracy Test
-        uses: ./.github/actions/inductor-xpu-e2e-test
-        with:
-          suite: torchbench
-          dt: bfloat16
-          mode: training
-          scenario: accuracy
-          env_prepare: true
-          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-      - name: Summarize archieve files
-        if: always()
-        run: |
-          rm -rf ${{ github.workspace }}/upload_files
-          cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files
-          failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
-          if [ ${failed_case} -ne 0 ];then
-            grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
-            exit 1
-          fi
-      - name: Upload Inductor XPU E2E Data
-        if: always()
-        uses: actions/upload-artifact@v4
-        with:
-          name: Inductor-XPU-E2E-Data-${{ github.event.pull_request.number || github.sha }}
-          path: ${{ github.workspace }}/upload_files
diff --git a/.github/workflows/inductor_xpu_e2e_nightly.yml b/.github/workflows/nightly_ondemand.yml
similarity index 68%
rename from .github/workflows/inductor_xpu_e2e_nightly.yml
rename to .github/workflows/nightly_ondemand.yml
index a8d316580..039407bc8 100644
--- a/.github/workflows/inductor_xpu_e2e_nightly.yml
+++ b/.github/workflows/nightly_ondemand.yml
@@ -1,4 +1,4 @@
-name: E2E Nightly_OnDemand Tests
+name: Nightly-OnDemand Tests
 
 on:
   schedule:
@@ -6,57 +6,56 @@ on:
    - cron: '0 13 * * *'
  workflow_dispatch:
    inputs:
-      python:
+      pytorch:
        required: false
        type: string
-        default: '3.10'
-        description: Specify python version
+        default: 'main'
+        description: PyTorch branch/commit
+      keep_torch_xpu_ops:
+        required: false
+        type: string
+        default: 'false'
+        description: Keep torch-xpu-ops pin. `true` means use pinned commit
+      ut:
+        required: true
+        type: string
+        default: 'torch_xpu'
+        description: UT scope. `op_example,op_extended,op_ut,torch_xpu`. Delimiter is comma
      triton:
        required: false
        type: string
        default: ''
-        description: Specify triton commit, use pytorch pined commit by default
+        description: Triton commit. Use PyTorch pinned commit by default
      suite:
        required: true
        type: string
        default: 'huggingface'
-        description: Dynamo benchmarks test suite. huggingface,timm_models,torchbench. Delimiter is comma
+        description: Dynamo benchmarks test suite. `huggingface,timm_models,torchbench`. Delimiter is comma
      dt:
        required: true
        type: string
        default: 'float32'
-        description: Data precision of the test.float32,bfloat16,float16,amp_bf16,amp_fp16. Delimiter is comma
+        description: Data precision of the test. `float32,bfloat16,float16,amp_bf16,amp_fp16`. Delimiter is comma
      mode:
        required: true
        type: string
        default: 'inference'
-        description: inference,training. Delimiter is comma
+        description: Test mode. `inference,training`. Delimiter is comma
      scenario:
        required: true
        type: string
        default: 'accuracy'
-        description: accuracy,performance. Delimiter is comma
+        description: Test scenario. `accuracy,performance`. Delimiter is comma
      model:
        required: false
        type: string
        default: ''
-        description: If set, will only launch this one
-      torch_xpu_ops_update:
-        required: false
-        type: string
-        default: 'true'
-        description: True means update xpu_ops when building pytorch, otherwise means not
-      ut_suite:
-        required: true
-        type: string
-        default: 'op_example,op_extended,op_ut,torch_xpu'
-        description: op_example,op_extended,op_ut,torch_xpu. Delimiter is comma
-      pytorch_branch:
+        description: Model. Will only run this one model if set
+      python:
        required: false
        type: string
-        default: 'main'
-        description: Set pytorch branch
-
+        default: '3.10'
+        description: Python version
 
 permissions: read-all
 
@@ -65,11 +64,26 @@ concurrency:
  cancel-in-progress: true
 
 jobs:
-  Inductor-XPU-E2E-Nightly-Tests:
+  Linux-Nightly-Ondemand-UT-Tests:
+    if: github.event_name == 'schedule' || ${{ inputs.ut }}
+    uses: ./.github/workflows/_linux_ut.yml
+    with:
+      keep_torch_xpu_ops: ${{ github.event_name == 'schedule' && 'false' || inputs.keep_torch_xpu_ops }}
+      ut: ${{ github.event_name == 'schedule' && 'op_example,op_extended,op_ut,torch_xpu' || inputs.ut }}
+      pytorch: ${{ github.event_name == 'schedule' && 'main' || inputs.pytorch }}
+      python: ${{ github.event_name == 'schedule' && '3.10' || inputs.python }}
+      runner: linux.idc.xpu
+
+  Linux-Nightly-Ondemand-E2E-Tests:
    runs-on: pvc_e2e
    # Don't run on forked repos
    if: github.repository_owner == 'intel'
    timeout-minutes: 900
+    env:
+      pytorch: ${{ github.event_name == 'schedule' && 'main' || inputs.pytorch }}
+      keep_torch_xpu_ops: ${{ github.event_name == 'schedule' && 'false' || inputs.keep_torch_xpu_ops }}
+      ut: ${{ github.event_name == 'schedule' && 'op_example,op_extended,op_ut,torch_xpu' || inputs.ut }}
+      python: ${{ github.event_name == 'schedule' && '3.10' || inputs.python }}
    outputs:
      TORCH_BRANCH_ID: ${{ steps.pinned.outputs.TORCH_BRANCH_ID }}
      TORCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCH_COMMIT_ID }}
@@ -80,7 +94,6 @@ jobs:
      TORCHBENCH_COMMIT_ID: ${{ steps.pinned.outputs.TORCHBENCH_COMMIT_ID }}
      TORCHVISION_COMMIT_ID: ${{ steps.pinned.outputs.TORCHVISION_COMMIT_ID }}
      TORCHAUDIO_COMMIT_ID: ${{ steps.pinned.outputs.TORCHAUDIO_COMMIT_ID }}
-      # TORCHTEXT_COMMIT_ID: ${{ steps.pinned.outputs.TORCHTEXT_COMMIT_ID }}
      TRANSFORMERS_VERSION: ${{ steps.pinned.outputs.TRANSFORMERS_VERSION }}
      TIMM_COMMIT_ID: ${{ steps.pinned.outputs.TIMM_COMMIT_ID }}
      TRITON_COMMIT_ID: ${{ steps.pinned.outputs.TRITON_COMMIT_ID }}
@@ -91,7 +104,7 @@ jobs:
      run: |
        which conda && conda clean -ay
        conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
-        conda create -n e2e_ci python=${{ inputs.python }} cmake ninja -y
+        conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
        source activate e2e_ci
        conda install -c intel mkl-static mkl-include -y
        pip install pandas scipy tqdm
@@ -100,16 +113,20 @@ jobs:
        pwd
        cd ../ && rm -rf pytorch
        source activate e2e_ci
-        git clone -b main https://github.com/pytorch/pytorch pytorch
-        cd pytorch
+        git clone https://github.com/pytorch/pytorch pytorch
+        cd pytorch && git checkout ${{ env.pytorch }}
        # apply PRs for stock pytorch
        pip install requests
        python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py
        git status && git show -s
        git submodule sync && git submodule update --init --recursive
-        rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
-        # Workaround for torch-xpu-ops ci test
-        sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
+        if [[ ${{ env.keep_torch_xpu_ops }} == 'true' ]]; then
+          echo "Don't replace torch-xpu-ops!"
+        else
+          rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
+          # Workaround for torch-xpu-ops ci test
+          sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
+        fi
    - name: Identify pinned versions
      id: pinned
      run: |
@@ -128,7 +145,7 @@ jobs:
        echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "MODEL_ONLY_NAME=${{ inputs.model }}" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        echo "DRIVER_VERSION=$(dkms status 2>&1 |grep 'intel-i915-dkms' |sed 's/.*\///;s/,.*//')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        echo "BUNDLE_VERSION=$(dpcpp --version 2>&1 |grep 'DPC++/C++' |sed 's/.*(//;s/).*//')" |tee -a "${GITHUB_OUTPUT}" >> "${GITHUB_ENV}"
        . /etc/os-release
@@ -148,7 +165,7 @@ jobs:
        cd ../pytorch
        pip install -r requirements.txt
        export USE_XPU=1
-        source /opt/intel/oneapi/compiler/latest/env/vars.sh
+        source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
        export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
        python setup.py bdist_wheel
        pip install --force-reinstall dist/*.whl
@@ -157,63 +174,18 @@ jobs:
        echo "$GITHUB_ENV"
        rm -rf ../pytorch/inductor_log
        rm -rf /tmp/torchinductor_*
-    - name: Nightly Huggingface FP32 Inference Accuracy Test
-      if: ${{ !inputs.suite }}
+    - name: Nightly Huggingface FP32/BF16/FP16 Inference & Training Accuracy Test
+      if: github.event_name == 'schedule'
      uses: ./.github/actions/inductor-xpu-e2e-test
      with:
        suite: huggingface
        env_prepare: true
-        dt: float32
-        mode: inference
-        scenario: accuracy
-        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-    - name: Nightly Huggingface BF16 Inference Accuracy Test
-      if: ${{ !inputs.suite }}
-      uses: ./.github/actions/inductor-xpu-e2e-test
-      with:
-        suite: huggingface
-        dt: bfloat16
-        mode: inference
-        scenario: accuracy
-        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-    - name: Nightly Huggingface FP16 Inference Accuracy Test
-      if: ${{ !inputs.suite }}
-      uses: ./.github/actions/inductor-xpu-e2e-test
-      with:
-        suite: huggingface
-        dt: float16
-        mode: inference
-        scenario: accuracy
-        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-    - name: Nightly Huggingface FP32 Training Accuracy Test
-      if: ${{ !inputs.suite }}
-      uses: ./.github/actions/inductor-xpu-e2e-test
-      with:
-        suite: huggingface
-        dt: float32
-        mode: training
-        scenario: accuracy
-        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-    - name: Nightly Huggingface BF16 Training Accuracy Test
-      if: ${{ !inputs.suite }}
-      uses: ./.github/actions/inductor-xpu-e2e-test
-      with:
-        suite: huggingface
-        dt: bfloat16
-        mode: training
-        scenario: accuracy
-        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-    - name: Nightly Huggingface FP16 Training Accuracy Test
-      if: ${{ !inputs.suite }}
-      uses: ./.github/actions/inductor-xpu-e2e-test
-      with:
-        suite: huggingface
-        dt: float16
-        mode: training
+        dt: float32,bfloat16,float16
+        mode: inference,training
        scenario: accuracy
        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
    - name: Nightly Torchbench BF16 Training Accuracy Test
-      if: ${{ !inputs.suite }}
+      if: github.event_name == 'schedule'
      uses: ./.github/actions/inductor-xpu-e2e-test
      with:
        suite: torchbench
@@ -223,7 +195,7 @@ jobs:
        env_prepare: true
        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
    - name: Nightly Timm_models FP16 Training Accuracy Test
-      if: ${{ !inputs.suite }}
+      if: github.event_name == 'schedule'
      uses: ./.github/actions/inductor-xpu-e2e-test
      with:
        suite: timm_models
@@ -233,7 +205,7 @@ jobs:
        env_prepare: true
        hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
    - name: OnDemand Test (${{ inputs.suite }} ${{ inputs.dt }} ${{ inputs.mode }} ${{ inputs.scenario }})
-      if: ${{ inputs.suite }}
+      if: github.event_name != 'schedule'
      uses: ./.github/actions/inductor-xpu-e2e-test
      with:
        suite: ${{ inputs.suite }}
@@ -258,15 +230,6 @@ jobs:
      with:
        name: Inductor-XPU-E2E-Data-${{ github.event.pull_request.number || github.sha }}
        path: ${{ github.workspace }}/upload_files
-
-  Inductor-XPU-UT-Nightly-Tests:
-    if: ${{ inputs.ut_suite }}
-    name: Nightly Inductor XPU UT Test
-    uses: ./.github/workflows/_linux_ut.yml
-    with:
-      ut_suite: ${{ inputs.ut_suite }}
-      pytorch_branch: ${{ inputs.pytorch_branch }}
-      runner: linux.idc.xpu
 
  Tests-Failure-And-Report:
    if: always()
@@ -275,7 +238,8 @@ jobs:
      issues: write
    env:
      GH_TOKEN: ${{ github.token }}
-    needs: Inductor-XPU-E2E-Nightly-Tests
+      python: ${{ github.event_name == 'schedule' && '3.10' || inputs.python }}
+    needs: Linux-Nightly-Ondemand-E2E-Tests
    steps:
    - name: Report github issue for XPU OPS nightly
      if: github.repository_owner == 'intel'
@@ -284,23 +248,23 @@
        # Test env
        build_url="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
        repo="${{ github.repository }}"
-        TORCH_BRANCH_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_BRANCH_ID }}"
-        TORCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_COMMIT_ID }}"
-        DRIVER_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.DRIVER_VERSION }}"
-        BUNDLE_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.BUNDLE_VERSION }}"
-        OS_PRETTY_NAME="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.OS_PRETTY_NAME }}"
-        GCC_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.GCC_VERSION }}"
-        TORCHBENCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHBENCH_COMMIT_ID }}"
-        TORCHVISION_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHVISION_COMMIT_ID }}"
-        TORCHAUDIO_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHAUDIO_COMMIT_ID }}"
-        # TORCHTEXT_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHTEXT_COMMIT_ID }}"
-        TRANSFORMERS_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRANSFORMERS_VERSION }}"
-        TIMM_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TIMM_COMMIT_ID }}"
-        TRITON_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRITON_COMMIT_ID }}"
+        TORCH_BRANCH_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCH_BRANCH_ID }}"
+        TORCH_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCH_COMMIT_ID }}"
+        DRIVER_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.DRIVER_VERSION }}"
+        BUNDLE_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.BUNDLE_VERSION }}"
+        OS_PRETTY_NAME="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.OS_PRETTY_NAME }}"
+        GCC_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.GCC_VERSION }}"
+        TORCHBENCH_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCHBENCH_COMMIT_ID }}"
+        TORCHVISION_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCHVISION_COMMIT_ID }}"
+        TORCHAUDIO_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCHAUDIO_COMMIT_ID }}"
+        # TORCHTEXT_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TORCHTEXT_COMMIT_ID }}"
+        TRANSFORMERS_VERSION="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TRANSFORMERS_VERSION }}"
+        TIMM_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TIMM_COMMIT_ID }}"
+        TRITON_COMMIT_ID="${{ needs.Linux-Nightly-Ondemand-E2E-Tests.outputs.TRITON_COMMIT_ID }}"
        # Test status
-        if [ "${{ needs.Inductor-XPU-E2E-Nightly-Tests.result }}" == "success" ];then
+        if [ "${{ needs.Linux-Nightly-Ondemand-E2E-Tests.result }}" == "success" ];then
          test_status=Success
-        elif [ "${{ needs.Inductor-XPU-E2E-Nightly-Tests.result }}" == "failure" ];then
+        elif [ "${{ needs.Linux-Nightly-Ondemand-E2E-Tests.result }}" == "failure" ];then
          test_status=Failure
          cc_comment="CC ${{ secrets.NIGHTLY_EMAIL_LIST }}"
        else
@@ -317,7 +281,7 @@ jobs:
          test_issue_id=432
        fi
        # Test report
-        echo -e "$cc_comment\n**${test_status}** $test_type Test on $(date +'%F'), See: $build_url\n" > ${{ github.workspace }}/report.txt
+        echo -e "**${test_status}** $test_type Test on $(date +'%F'), See: $build_url\n" > ${{ github.workspace }}/report.txt
        printf "Torch-xpu-ops | PyTorch | Triton\n--- | --- | ---\n${GITHUB_WORKFLOW_SHA:0:7} on ${GITHUB_REF_NAME} | " >> ${{ github.workspace }}/report.txt
        printf "[${TORCH_COMMIT_ID:0:7}](https://github.com/pytorch/pytorch/commit/${TORCH_COMMIT_ID:0:7}) on $TORCH_BRANCH_ID | " >> ${{ github.workspace }}/report.txt
        echo -e "[${TRITON_COMMIT_ID:0:7}](https://github.com/intel/intel-xpu-backend-for-triton/commit/${TRITON_COMMIT_ID:0:7}) \n" >> ${{ github.workspace }}/report.txt
@@ -328,7 +292,7 @@ jobs:
        printf "[${TORCHVISION_COMMIT_ID:0:7}](https://github.com/pytorch/vision/commit/${TORCHVISION_COMMIT_ID:0:7}) | " >> ${{ github.workspace }}/report.txt
        echo -e "[${TORCHAUDIO_COMMIT_ID:0:7}](https://github.com/pytorch/audio/commit/${TORCHAUDIO_COMMIT_ID:0:7}) \n" >> ${{ github.workspace }}/report.txt
        printf "Device | OS | GCC | Python | Driver(DKMS) | Bundle(DPCPP)\n--- | --- | --- | --- | --- | ---\n" >> ${{ github.workspace }}/report.txt
-        echo -e "$RUNNER_NAME | $OS_PRETTY_NAME | $GCC_VERSION | ${{ inputs.python }} | $DRIVER_VERSION| $BUNDLE_VERSION \n" >> ${{ github.workspace }}/report.txt
+        echo -e "$RUNNER_NAME | $OS_PRETTY_NAME | $GCC_VERSION | ${{ env.python }} | $DRIVER_VERSION| $BUNDLE_VERSION \n" >> ${{ github.workspace }}/report.txt
        if [ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" ];then
          test_scope="${{ inputs.suite }}/${{ inputs.dt }}/${{ inputs.mode }}/${{ inputs.scenario }}"
          if [ "${{ inputs.triton }}" != "" ];then
@@ -339,6 +303,7 @@ jobs:
          fi
          echo -e "Inputs | $test_scope\n--- | --- \n" >> ${{ github.workspace }}/report.txt
        fi
+        echo -e "$cc_comment\n" >> ${{ github.workspace }}/report.txt
        # Report
        report_txt=$(cat ${{ github.workspace }}/report.txt)
        gh --repo $repo issue comment $test_issue_id --body "$report_txt"
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 1bd635d1a..3b103e4d9 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -22,9 +22,123 @@ jobs:
  preci-ut:
    # Don't run on forked repos and draft PRs
    if: ${{ (github.repository_owner == 'intel') && (github.event.pull_request.draft == false) }}
-    name: preci-ut
    uses: ./.github/workflows/_linux_ut.yml
    with:
-      ut_suite: op_example,op_extended,op_ut,torch_xpu
+      ut: op_example,op_extended,op_ut
      runner: linux.idc.xpu
-
\ No newline at end of file
+
+  Inductor-XPU-E2E-CI-Tests:
+    runs-on: pvc_e2e
+    # Don't run on forked repos and draft PRs
+    if: ${{ (github.repository_owner == 'intel') && (github.event.pull_request.draft == false) }}
+    timeout-minutes: 900
+    steps:
+      - name: Checkout torch-xpu-ops
+        uses: actions/checkout@v4
+      - name: Prepare Conda ENV
+        run: |
+          which conda && conda clean -ay
+          conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
+          conda create -n e2e_ci python=3.10 cmake ninja -y
+          source activate e2e_ci
+          conda install -c intel mkl-static mkl-include -y
+          pip install pandas scipy tqdm
+      - name: Prepare Stock Pytorch
+        run: |
+          pwd
+          cd ../ && rm -rf pytorch
+          source activate e2e_ci
+          git clone -b main https://github.com/pytorch/pytorch pytorch
+          cd pytorch
+          # apply PRs for stock pytorch
+          pip install requests
+          # https://github.com/mengfei25/pytorch/pull/18 internal use only for subset model list
+          python ../torch-xpu-ops/.github/scripts/apply_torch_pr.py -e https://github.com/mengfei25/pytorch/pull/18
+          git status && git show -s
+          git submodule sync && git submodule update --init --recursive
+          rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
+          # Workaround for torch-xpu-ops ci test
+          sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
+      - name: Triton Installation
+        run: |
+          source activate e2e_ci
+          cd ../pytorch
+          TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton"
+          TRITON_PINNED_COMMIT=$(cat .ci/docker/ci_commit_pins/triton-xpu.txt)
+          echo ${TRITON_REPO}@${TRITON_PINNED_COMMIT}
+          pip install --force-reinstall "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
+      - name: Build Pytorch XPU
+        run: |
+          source activate e2e_ci
+          cd ../pytorch
+          pip install -r requirements.txt
+          export USE_XPU=1
+          source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
+          export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
+          python setup.py bdist_wheel
+          pip install --force-reinstall dist/*.whl
+      - name: Identify pinned versions
+        run: |
+          cd ../pytorch
+          echo "TRITON_COMMIT_ID=$(<.ci/docker/ci_commit_pins/triton-xpu.txt)" >> "${GITHUB_ENV}"
+          echo "TORCHVISION_COMMIT_ID=$(<.github/ci_commit_pins/vision.txt)" >> "${GITHUB_ENV}"
+          echo "TORCHTEXT_COMMIT_ID=$(<.github/ci_commit_pins/text.txt)" >> "${GITHUB_ENV}"
+          echo "TORCHAUDIO_COMMIT_ID=$(<.github/ci_commit_pins/audio.txt)" >> "${GITHUB_ENV}"
+          echo "TRANSFORMERS_VERSION=$(<.ci/docker/ci_commit_pins/huggingface.txt)" >> "${GITHUB_ENV}"
+          echo "TIMM_COMMIT_ID=$(<.ci/docker/ci_commit_pins/timm.txt)" >> "${GITHUB_ENV}"
+      - name: Show GITHUB_ENV
+        run: |
+          echo "$GITHUB_ENV"
+          rm -rf ../pytorch/inductor_log
+          rm -rf /tmp/torchinductor_*
+      - name: Huggingface BF16 Training Accuracy Test
+        uses: ./.github/actions/inductor-xpu-e2e-test
+        with:
+          suite: huggingface
+          dt: bfloat16
+          mode: training
+          scenario: accuracy
+          env_prepare: true
+          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+      - name: Huggingface FP16 Training Accuracy Test
+        uses: ./.github/actions/inductor-xpu-e2e-test
+        with:
+          suite: huggingface
+          dt: float16
+          mode: training
+          scenario: accuracy
+          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+      - name: Timm_models BF16 Training Accuracy Test
+        uses: ./.github/actions/inductor-xpu-e2e-test
+        with:
+          suite: timm_models
+          dt: bfloat16
+          mode: training
+          scenario: accuracy
+          env_prepare: true
+          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+      - name: Torchbench BF16 Training Accuracy Test
+        uses: ./.github/actions/inductor-xpu-e2e-test
+        with:
+          suite: torchbench
+          dt: bfloat16
+          mode: training
+          scenario: accuracy
+          env_prepare: true
+          hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+      - name: Summarize archive files
+        if: always()
+        run: |
+          rm -rf ${{ github.workspace }}/upload_files
+          cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files
+          failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true)
+          if [ ${failed_case} -ne 0 ];then
+            grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log
+            exit 1
+          fi
+      - name: Upload Inductor XPU E2E Data
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: Inductor-XPU-E2E-Data-${{ github.event.pull_request.number || github.sha }}
+          path: ${{ github.workspace }}/upload_files
diff --git a/src/ATen/native/xpu/PointwiseOps.cpp b/src/ATen/native/xpu/PointwiseOps.cpp
index 210cec3e6..a01bdc391 100644
--- a/src/ATen/native/xpu/PointwiseOps.cpp
+++ b/src/ATen/native/xpu/PointwiseOps.cpp
@@ -6,6 +6,63 @@
 
 namespace at {
 
+TensorIterator addcdiv_meta(
+    const Tensor& self,
+    const Tensor& tensor1,
+    const Tensor& tensor2,
+    const Scalar& value,
+    Tensor& out) {
+  if (isIntegralType(tensor1.scalar_type(), /*includeBool=*/true) &&
+      isIntegralType(tensor2.scalar_type(), /*includeBool=*/true)) {
+    TORCH_CHECK(
+        false,
+        "Integer division with addcdiv is no longer supported, and in a future ",
+        "release addcdiv will perform a true division of tensor1 and tensor2. ",
+        "The historic addcdiv behavior can be implemented as ",
+        "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
+        "for integer inputs and as ",
+        "(input + value * tensor1 / tensor2) for float inputs. ",
", + "The future addcdiv behavior is just the latter implementation: ", + "(input + value * tensor1 / tensor2), for all dtypes."); + } + + TensorIterator iter; + iter.build_ternary_op(out, self, tensor1, tensor2); + return iter; +} + +Tensor& XPUNativeFunctions::addcdiv_out( + const Tensor& self, + const Tensor& tensor1, + const Tensor& tensor2, + const Scalar& value, + Tensor& out) { + auto iter = addcdiv_meta(self, tensor1, tensor2, value, out); + native::xpu::addcdiv_kernel(iter, value); + return out; +} + +Tensor XPUNativeFunctions::addcdiv( + const Tensor& self, + const Tensor& tensor1, + const Tensor& tensor2, + const Scalar& value) { + Tensor out; + auto iter = addcdiv_meta(self, tensor1, tensor2, value, out); + native::xpu::addcdiv_kernel(iter, value); + return iter.output(); +} + +Tensor& XPUNativeFunctions::addcdiv_( + Tensor& self, + const Tensor& tensor1, + const Tensor& tensor2, + const Scalar& value) { + auto iter = addcdiv_meta(self, tensor1, tensor2, value, self); + native::xpu::addcdiv_kernel(iter, value); + return self; +} + TensorIterator addcmul_meta( const Tensor& self, const Tensor& tensor1, diff --git a/src/ATen/native/xpu/TensorProperties.cpp b/src/ATen/native/xpu/TensorProperties.cpp new file mode 100644 index 000000000..428d18fcd --- /dev/null +++ b/src/ATen/native/xpu/TensorProperties.cpp @@ -0,0 +1,16 @@ +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +namespace at { + +bool XPUNativeFunctions::is_set_to(const Tensor& self, const Tensor& src) { + return at::native::is_set_to(self, src); +} + +} // namespace at diff --git a/src/ATen/native/xpu/UnaryOps.cpp b/src/ATen/native/xpu/UnaryOps.cpp index 63ba1fd3c..1d3137f79 100644 --- a/src/ATen/native/xpu/UnaryOps.cpp +++ b/src/ATen/native/xpu/UnaryOps.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -630,6 +631,18 @@ Tensor& XPUNativeFunctions::cosh_out(const Tensor& self, Tensor& out) { return out; } +Tensor& XPUNativeFunctions::conj_physical_out(const Tensor& self, Tensor& out) { + auto iter = TensorIterator::unary_op(out, self); + native::xpu::conj_physical_kernel(iter); + return out; +} + +Tensor& XPUNativeFunctions::conj_physical_(Tensor& self) { + if (!self.is_complex()) + return self; + return XPUNativeFunctions::conj_physical_out(self, self); +} + TensorIterator ceil_meta(const Tensor& self, Tensor& out) { TORCH_CHECK(!self.is_complex(), "ceil is not supported for complex inputs"); TensorIterator iter; diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 8e66c0fea..f320725df 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -163,7 +163,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "adaptive_max_pool2d.out", "adaptive_max_pool3d_backward.grad_input", "adaptive_max_pool3d.out", - "addcdiv.out", "aminmax.out", "angle", "argmin.out", @@ -179,7 +178,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "cholesky", "cholesky_inverse", "_cholesky_solve_helper", - "conj_physical.out", "copysign.out", "count_nonzero.dim_IntList", "_ctc_loss", diff --git a/src/ATen/native/xpu/sycl/Col2ImKernel.cpp b/src/ATen/native/xpu/sycl/Col2ImKernel.cpp index a0ba012c1..299711162 100644 --- a/src/ATen/native/xpu/sycl/Col2ImKernel.cpp +++ b/src/ATen/native/xpu/sycl/Col2ImKernel.cpp @@ -200,7 +200,7 @@ void col2im_kernel( bool batched_input = true; if (input.dim() == 2) { batched_input = false; - input.resize_({1, input.size(0), input.size(1)}); + input = input.view({1, 
  }
 
  auto batch_size = input.size(0);
diff --git a/src/ATen/native/xpu/sycl/Im2ColKernel.cpp b/src/ATen/native/xpu/sycl/Im2ColKernel.cpp
index 149665bc7..aa511e6df 100644
--- a/src/ATen/native/xpu/sycl/Im2ColKernel.cpp
+++ b/src/ATen/native/xpu/sycl/Im2ColKernel.cpp
@@ -187,7 +187,7 @@ void im2col_kernel(
 
  if (input.dim() == 3) {
    batched_input = false;
-    input.resize_({1, input.size(0), input.size(1), input.size(2)});
+    input = input.view({1, input.size(0), input.size(1), input.size(2)});
  }
 
  auto batch_size = input.size(0);
diff --git a/src/ATen/native/xpu/sycl/IndexingUtils.h b/src/ATen/native/xpu/sycl/IndexingUtils.h
index 1c6d9c373..26eb2f1ea 100644
--- a/src/ATen/native/xpu/sycl/IndexingUtils.h
+++ b/src/ATen/native/xpu/sycl/IndexingUtils.h
@@ -99,7 +99,7 @@ static std::tuple computeLinearIndex(
 static std::
    tuple> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
-  checkIndexTensorTypes(orig, /*allow_int*/ true);
+  checkIndexTensorTypes(orig);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
  // LongTensors
  auto indices = expandTensors(self, orig);
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
index 7b00d09e3..d38f511d7 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp
@@ -1,6 +1,6 @@
 #include 
+#include 
 #include 
-#include 
 #include 
 
 #include 
@@ -8,31 +8,98 @@ namespace at::native::xpu {
 
 template <typename scalar_t>
-struct AddcmulKernelFunctor {
-  using opmath_t = at::opmath_type<scalar_t>;
+struct AddcmulFunctor {
+  using accscalar_t = at::acc_type<scalar_t, true>;
  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
-    return static_cast<opmath_t>(a) +
-        alpha_ * static_cast<opmath_t>(b) * static_cast<opmath_t>(c);
+    return static_cast<accscalar_t>(a) +
+        alpha_ * static_cast<accscalar_t>(b) * static_cast<accscalar_t>(c);
  }
 
-  AddcmulKernelFunctor(opmath_t alpha) : alpha_(alpha) {}
+  AddcmulFunctor(accscalar_t alpha) : alpha_(alpha) {}
 
 private:
-  opmath_t alpha_;
+  accscalar_t alpha_;
+};
+
+template <typename scalar_t>
+struct AddcmulComplexFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * b * c;
+  }
+
+  AddcmulComplexFunctor(scalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  scalar_t alpha_;
 };
 
 void addcmul_kernel(TensorIterator& iter, Scalar value) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      at::ScalarType::Half,
-      at::ScalarType::BFloat16,
-      iter.dtype(),
-      "addcmul_xpu",
-      [&]() {
-        using opmath_t = at::opmath_type<scalar_t>;
-        auto alpha = value.to<opmath_t>();
-        AddcmulKernelFunctor<scalar_t> f(alpha);
-        gpu_kernel(iter, f);
-      });
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_xpu", [&]() {
+      auto alpha = value.to<scalar_t>();
+      gpu_kernel(iter, AddcmulComplexFunctor<scalar_t>(alpha));
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        iter.dtype(),
+        "addcmul_xpu",
+        [&]() {
+          using accscalar_t = at::acc_type<scalar_t, true>;
+          auto alpha = value.to<accscalar_t>();
+          gpu_kernel(iter, AddcmulFunctor<scalar_t>(alpha));
+        });
+  }
+}
+
+template <typename scalar_t>
+struct AddcdivFunctor {
+  using accscalar_t = at::acc_type<scalar_t, true>;
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * (b / static_cast<accscalar_t>(c));
+  }
+
+  AddcdivFunctor(accscalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  accscalar_t alpha_;
+};
+
+template <typename scalar_t>
+struct AddcdivComplexFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b, scalar_t c) const {
+    return a + alpha_ * (b / c);
+  }
+
+  AddcdivComplexFunctor(scalar_t alpha) : alpha_(alpha) {}
+
+ private:
+  scalar_t alpha_;
+};
+
+void addcdiv_kernel(TensorIterator& iter, Scalar value) {
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_xpu", [&]() {
+      auto alpha = value.to<scalar_t>();
+      AddcdivComplexFunctor<scalar_t> f(alpha);
+      gpu_kernel(iter, f);
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        iter.dtype(),
+        "addcdiv_xpu",
+        [&]() {
+          using accscalar_t = at::acc_type<scalar_t, true>;
+          auto alpha = value.to<accscalar_t>();
+          AddcdivFunctor<scalar_t> f(alpha);
+          gpu_kernel(iter, f);
+        });
+  }
+}
 
 template <typename scalar_t>
diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
index fdb216dbd..c775b88e5 100644
--- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h
@@ -6,6 +6,8 @@ namespace at::native::xpu {
 
 void addcmul_kernel(TensorIterator& iter, Scalar value);
 
+void addcdiv_kernel(TensorIterator& iter, Scalar value);
+
 void mse_backward_kernel(TensorIterator& iter, const Scalar& value);
 
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp b/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
index e082096c1..87de57a3a 100644
--- a/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
+++ b/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
@@ -30,6 +30,32 @@ void conj_kernel(TensorIterator& iter) {
      }));
 }
 
+template <typename scalar_t>
+struct ConjPhysicalFunctor {
+  scalar_t operator()(scalar_t z) const {
+    return std::conj(z);
+  }
+};
+
+template <typename T>
+struct ConjPhysicalFunctor<c10::complex<T>> {
+  c10::complex<T> operator()(c10::complex<T> z) const {
+    return c10::complex<T>(z.real(), -z.imag());
+  }
+};
+
+void conj_physical_kernel(TensorIterator& iter) {
+  AT_DISPATCH_SWITCH(
+      iter.common_dtype(),
+      "conj_xpu",
+      AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
+        // Conj is a no-op for non-complex types
+        copy_kernel(iter);
+      }) AT_DISPATCH_CASE_COMPLEX_TYPES_AND(kComplexHalf, [&] {
+        gpu_kernel(iter, ConjPhysicalFunctor<scalar_t>());
+      }));
+}
+
 template <typename scalar_t>
 struct NegConjScalarFunc {
  scalar_t operator()(scalar_t src_val) const {
diff --git a/src/ATen/native/xpu/sycl/UnaryComplexKernels.h b/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
index 8d19381b3..d3ad4fe15 100644
--- a/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
+++ b/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
@@ -6,6 +6,8 @@ namespace at::native::xpu {
 
 void conj_kernel(TensorIterator& iter);
 
+void conj_physical_kernel(TensorIterator& iter);
+
 void neg_conj_kernel(TensorIterator& iter);
 
 void neg_kernel(TensorIterator& iter);
diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index c7c2ff404..943d46465 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -128,6 +128,18 @@
    # Greatest absolute difference: 0.03125 at index (610,) (up to 0.001 allowed)
    # Greatest relative difference: 0.00396728515625 at index (610,) (up to 0.001 allowed)
    "test_compare_cpu_hypot_xpu_bfloat16",
+
+    # Regressions due to PyTorch uplift (Numeric difference in float and bfloat)
+    # https://github.com/intel/torch-xpu-ops/issues/549
+    # Example fail log
+    # FAILED test_ops_xpu.py::TestCommonXPU::test_compare_cpu_nn_functional_batch_norm_xpu_float16 - AssertionError: Tensor-likes are not close!
+    # Mismatched elements: 3 / 72 (4.2%)
+    # Greatest absolute difference: 0.0029296875 at index (0, 1, 1, 0) (up to 0.001 allowed)
+    # Greatest relative difference: 0.0032501220703125 at index (2, 1, 2, 1) (up to 0.001 allowed)
+    "test_compare_cpu_nn_functional_batch_norm_xpu_float16",
+    "test_compare_cpu_std_mean_xpu_bfloat16",
+    "test_compare_cpu_sub_xpu_float16",
+    "test_compare_cpu_var_mean_xpu_bfloat16",
 )
diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 43362fec5..4a146e0e0 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -207,7 +207,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
    "test_python_ref_torch_fallback__refs_square_xpu_bool",
    "test_python_ref_torch_fallback__refs_vdot_xpu_complex128",
    "test_python_ref_torch_fallback__refs_vdot_xpu_complex64",
-    "test_variant_consistency_eager_conj_physical_xpu_complex64",
    "test_variant_consistency_eager_nn_functional_conv_transpose2d_xpu_complex64",
    "test_variant_consistency_eager_nn_functional_conv_transpose2d_xpu_float32",
    "test_variant_consistency_eager_nn_functional_conv_transpose3d_xpu_complex64",
@@ -242,8 +241,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
    "test_python_ref_executor__refs_square_executor_aten_xpu_complex128",
    "test_python_ref_torch_fallback__refs_square_xpu_complex128",
    "test_python_ref_torch_fallback__refs_square_xpu_complex64",
-    "test_conj_view_conj_physical_xpu_complex64",
-    "test_neg_conj_view_conj_physical_xpu_complex128",
 
    # Skip list of new added when porting XPU operators.
    # See: https://github.com/intel/torch-xpu-ops/issues/128
@@ -1291,9 +1288,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
    # NotImplementedError: Could not run 'aten::_indices' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build).
    "test_EmbeddingBag_sparse_cuda",
    "test_Embedding_sparse_cuda",
-    # col2im: AssertionError: The values for attribute 'shape' do not match: torch.Size([16, 4]) != torch.Size([1, 16, 4]).
-    "test_Fold_no_batch_dim_input_cuda", # col2im
-    "test_Fold_no_batch_dim_int_input_cuda",
    # AssertionError: 'XPU error: device-side assert triggered' not found in ' File "<string>", line 8\n def test_cross_entropy_loss_2d_out_of_bounds_class_index(self):\n ^\nIndentationError: expected an indented block\n'
    "test_cross_entropy_loss_2d_out_of_bounds_class_index_xpu_float16",
    "test_cross_entropy_loss_2d_out_of_bounds_class_index_xpu_float32",
@@ -2206,9 +2200,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
    # torch.autograd.gradcheck.GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0,
    "test_fn_fwgrad_bwgrad_nn_functional_rrelu_xpu_float64",
    "test_forward_mode_AD_nn_functional_rrelu_xpu_float64",
-    # RuntimeError: DispatchStub: unsupported device typexpu
-    "test_inplace_forward_mode_AD_conj_physical_xpu_complex128",
-    # NotImplementedError: Could not run 'aten::_to_dense' with arguments from the 'SparseXPU' backend.
+# NotImplementedError: Could not run 'aten::_to_dense' with arguments from the 'SparseXPU' backend.
"test_fn_fwgrad_bwgrad_to_sparse_xpu_float64", "test_forward_mode_AD_to_sparse_xpu_float64", ) @@ -2744,9 +2736,6 @@ def launch_test(test_case, skip_list=None, exe_list=None): ### Error #7 in TestBwdGradientsXPU , totally 2 , NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XPU, Meta, SparseCPU, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher]. "test_fn_grad_to_sparse_xpu_float64", "test_fn_gradgrad_to_sparse_xpu_float64", - ### Error #8 in TestBwdGradientsXPU , totally 2 , RuntimeError: DispatchStub: unsupported device typexpu - "test_inplace_grad_conj_physical_xpu_complex128", - "test_inplace_gradgrad_conj_physical_xpu_complex128", ) res += launch_test("test_ops_gradients_xpu.py", skip_list) @@ -3005,23 +2994,21 @@ def launch_test(test_case, skip_list=None, exe_list=None): res += launch_test("nn/test_convolution_xpu.py", skip_list) # test_dynamic_shapes - - -res += launch_test("test_dynamic_shapes_xpu.py") +skip_list = ( + # Regression after PyTorch uplift + # https://github.com/intel/torch-xpu-ops/issues/549 + # AssertionError: 3 != 3.0 + "test_symnode_hashing", +) +res += launch_test("test_dynamic_shapes_xpu.py", skip_list) # test_load_state_dict - - res += launch_test("nn/test_load_state_dict_xpu.py") # test_module_hooks - - res += launch_test("nn/test_module_hooks_xpu.py") # test_parametrization - - res += launch_test("nn/test_parametrization_xpu.py") exit_code = os.WEXITSTATUS(res) diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 0be80b6de..e68df6e1e 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -44,6 +44,7 @@ "bitwise_or", "bitwise_xor", "addcmul", + "addcdiv", "clamp", "clamp_max", "clamp_min", @@ -161,6 +162,7 @@ "bincount", "renorm", "lerp", + "conj_physical", ] diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index b93aa3602..b48bc8769 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -455,6 +455,7 @@ supported: - _cdist_forward - _pin_memory - is_pinned + - is_set_to - bucketize.Tensor - bucketize.Tensor_out - bucketize.Scalar @@ -496,6 +497,9 @@ supported: - avg_pool2d.out - avg_pool2d_backward - avg_pool2d_backward.grad_input + - addcdiv.out + - addcdiv + - addcdiv_ - addcmul.out - addcmul - addcmul_ @@ -520,6 +524,8 @@ supported: - randperm.generator_out - _amp_foreach_non_finite_check_and_unscale_ - _amp_update_scale_ + - conj_physical.out + - conj_physical_ - ceil - ceil_ - ceil.out