diff --git a/.github/actions/inductor-xpu-e2e-test/action.yml b/.github/actions/inductor-xpu-e2e-test/action.yml index 5647da77e..6e1dd4268 100644 --- a/.github/actions/inductor-xpu-e2e-test/action.yml +++ b/.github/actions/inductor-xpu-e2e-test/action.yml @@ -110,14 +110,11 @@ runs: contains "accuracy,performance" $scenario $contains_status if [ "${MODEL_ONLY_NAME}" == "" ];then - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 0 static 8 0 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 1 static 8 1 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 2 static 8 2 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 3 static 8 3 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 4 static 8 4 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 5 static 8 5 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 6 static 8 6 & - bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 7 static 8 7 & + xpu_list=($(xpu-smi discovery |grep 'DRM Device: /dev/' |sed 's/.*card//;s/[^0-9].*//' |awk '{print $1 - 1":"NR - 1}')) + for xpu_id in ${xpu_list[*]} + do + bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu ${xpu_id/:*} static ${#xpu_list[*]} ${xpu_id/*:} & + done else bash inductor_xpu_test.sh ${suite} ${dt} ${mode} ${scenario} xpu 0 static 1 0 ${MODEL_ONLY_NAME} & fi diff --git a/.github/ci_expected_accuracy/inductor_timm_models_inference.csv b/.github/ci_expected_accuracy/inductor_timm_models_inference.csv index ea6cbfdb0..ffaa26f62 100644 --- a/.github/ci_expected_accuracy/inductor_timm_models_inference.csv +++ b/.github/ci_expected_accuracy/inductor_timm_models_inference.csv @@ -37,7 +37,7 @@ mobilevit_s,pass,pass,pass,pass,pass nfnet_l0,pass,pass,pass,pass,pass pit_b_224,pass,pass,pass,pass,pass pnasnet5large,pass,pass,pass,pass,pass -poolformer_m36,pass,pass,fail_accuracy,pass,pass +poolformer_m36,pass,pass,pass,pass,pass regnety_002,pass,pass,pass,pass,pass repvgg_a2,pass,pass,pass,pass,pass res2net101_26w_4s,pass,pass,pass,pass,pass diff --git a/.github/ci_expected_accuracy/inductor_timm_models_training.csv b/.github/ci_expected_accuracy/inductor_timm_models_training.csv index 8a173f98f..df939e7db 100644 --- a/.github/ci_expected_accuracy/inductor_timm_models_training.csv +++ b/.github/ci_expected_accuracy/inductor_timm_models_training.csv @@ -36,8 +36,8 @@ mobilenetv3_large_100,pass,pass,fail_accuracy,pass,pass mobilevit_s,pass,pass,fail_accuracy,pass,pass nfnet_l0,pass,pass,pass,pass,pass pit_b_224,pass,pass,pass,pass,pass -pnasnet5large,pass,pass,fail_accuracy,pass,fail_accuracy -poolformer_m36,pass,pass,fail_accuracy,pass,pass +pnasnet5large,pass,pass,pass,pass,fail_accuracy +poolformer_m36,pass,pass,pass,pass,pass regnety_002,pass,pass,fail_accuracy,pass,pass repvgg_a2,pass,pass,fail_accuracy,pass,pass res2net101_26w_4s,pass,pass,fail_accuracy,pass,pass diff --git a/.github/ci_expected_accuracy/inductor_torchbench_inference.csv b/.github/ci_expected_accuracy/inductor_torchbench_inference.csv index 931fb45b3..7daebdbd4 100644 --- a/.github/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/.github/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -5,7 +5,7 @@ Background_Matting,pass_due_to_skip,pass_due_to_skip,eager_fail_to_run,pass_due_ DALLE2_pytorch,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run LearningToPaint,pass,pass,pass,pass,pass Super_SloMo,pass,pass,pass,pass,pass -alexnet,eager_two_runs_differ,pass,eager_two_runs_differ,pass,eager_two_runs_differ +alexnet,eager_two_runs_differ,pass,pass,pass,eager_two_runs_differ basic_gnn_edgecnn,pass,pass,pass,pass,pass basic_gnn_gcn,pass,pass,pass,pass,pass basic_gnn_gin,pass,pass,pass,pass,pass @@ -20,11 +20,11 @@ detectron2_fasterrcnn_r_101_fpn,pass,eager_fail_to_run,fail_accuracy,eager_fail_ detectron2_fasterrcnn_r_50_c4,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy detectron2_fasterrcnn_r_50_dc5,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy detectron2_fasterrcnn_r_50_fpn,pass,eager_fail_to_run,fail_accuracy,eager_fail_to_run,fail_accuracy -detectron2_fcos_r_50_fpn,pass,fail_accuracy,fail_accuracy,pass,fail_accuracy +detectron2_fcos_r_50_fpn,pass,pass,pass,pass,pass detectron2_maskrcnn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run detectron2_maskrcnn_r_101_c4,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run detectron2_maskrcnn_r_101_fpn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run -detectron2_maskrcnn_r_50_c4,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run +detectron2_maskrcnn_r_50_c4,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run detectron2_maskrcnn_r_50_fpn,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run dlrm,pass,pass,pass,pass,pass doctr_det_predictor,pass,pass,pass,eager_fail_to_run,pass @@ -61,7 +61,7 @@ mnasnet1_0,pass,pass,pass,pass,pass mobilenet_v2,pass,pass,pass,pass,pass mobilenet_v2_quantized_qat,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load mobilenet_v3_large,pass,pass,pass,pass,pass -moco,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run +moco,model_fail_to_load,model_fail_to_load,model_fail_to_load,eager_fail_to_run,model_fail_to_load moondream,pass,pass,pass,pass,pass nanogpt,pass,pass,pass,pass,pass nvidia_deeprecommender,pass,pass,pass,pass,pass @@ -89,7 +89,7 @@ speech_transformer,pass,pass,pass,pass,pass squeezenet1_1,pass,fail_accuracy,fail_accuracy,pass,pass stable_diffusion_text_encoder,pass,pass,pass,pass,pass stable_diffusion_unet,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip -tacotron2,pass,pass,pass,model_fail_to_load,model_fail_to_load +tacotron2,pass,pass,pass,model_fail_to_load,fail_to_run timm_efficientdet,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load timm_efficientnet,pass,pass,pass,pass,pass timm_nfnet,pass,pass,pass,pass,pass @@ -98,8 +98,8 @@ timm_resnest,pass,pass,pass,pass,pass timm_vision_transformer,pass,pass,pass,pass,pass timm_vision_transformer_large,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip timm_vovnet,pass,pass,pass,pass,pass -torch_multimodal_clip,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run +torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run -vgg16,eager_two_runs_differ,pass,eager_two_runs_differ,pass,pass +vgg16,eager_two_runs_differ,pass,pass,pass,pass vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run yolov3,pass,pass,pass,pass,pass diff --git a/.github/ci_expected_accuracy/inductor_torchbench_training.csv b/.github/ci_expected_accuracy/inductor_torchbench_training.csv index b660a0829..94868d276 100644 --- a/.github/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/.github/ci_expected_accuracy/inductor_torchbench_training.csv @@ -1,5 +1,5 @@ name,float32,bfloat16,float16,amp_bf16,amp_fp16 -torchrec_dlrm,fail_to_run,eager_fail_to_run,eager_fail_to_run,fail_to_run,fail_to_run +torchrec_dlrm,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,fail_to_run BERT_pytorch,pass,pass,pass,pass,pass Background_Matting,pass_due_to_skip,pass_due_to_skip,eager_fail_to_run,pass_due_to_skip,eager_fail_to_run DALLE2_pytorch,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run @@ -53,7 +53,7 @@ hf_distil_whisper,model_fail_to_load,model_fail_to_load,model_fail_to_load,model lennard_jones,pass,pass,pass,pass,pass llama,pass,pass,pass,pass,pass llama_v2_7b_16h,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip -llava,eager_fail_to_run,eager_2nd_run_fail,eager_2nd_run_fail,eager_fail_to_run,eager_fail_to_run +llava,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run maml,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run maml_omniglot,pass,pass,pass,pass,pass microbench_unbacked_tolist_sum,pass,pass,pass,pass,pass @@ -61,7 +61,7 @@ mnasnet1_0,pass,pass,pass,pass,pass mobilenet_v2,pass,pass,pass,pass,pass mobilenet_v2_quantized_qat,fail_accuracy,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run mobilenet_v3_large,pass,pass,pass,pass,pass -moco,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run +moco,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,eager_fail_to_run moondream,pass,pass,pass,pass,pass nanogpt,pass,pass,pass,pass,pass nvidia_deeprecommender,pass,pass,pass,pass,pass @@ -91,14 +91,14 @@ stable_diffusion_text_encoder,pass,pass,pass,pass,pass stable_diffusion_unet,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip tacotron2,fail_to_run,fail_to_run,fail_to_run,fail_to_run,fail_to_run timm_efficientdet,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load,model_fail_to_load -timm_efficientnet,pass,pass,pass,fail_accuracy,pass +timm_efficientnet,pass,pass,pass,pass,pass timm_nfnet,pass,pass,pass,pass,pass timm_regnet,pass,pass,pass,pass,pass timm_resnest,pass,pass,pass,pass,pass timm_vision_transformer,pass,pass,pass,pass,pass timm_vision_transformer_large,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip,pass_due_to_skip timm_vovnet,pass,pass,pass,pass,pass -torch_multimodal_clip,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run +torch_multimodal_clip,pass,pass,pass,eager_fail_to_run,eager_fail_to_run tts_angular,pass,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run,eager_fail_to_run vgg16,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ,eager_two_runs_differ vision_maskrcnn,pass,pass,pass,eager_fail_to_run,eager_fail_to_run diff --git a/.github/scripts/apply_torch_pr.py b/.github/scripts/apply_torch_pr.py index 2377b9826..d0ab9a163 100644 --- a/.github/scripts/apply_torch_pr.py +++ b/.github/scripts/apply_torch_pr.py @@ -15,6 +15,8 @@ "https://github.com/pytorch/pytorch/pull/127277", # [Inductor][Intel GPU] Support reduction split. "https://github.com/pytorch/pytorch/pull/129120", + # Modify the tolerance level in TIMM benchmark + "https://github.com/pytorch/pytorch/pull/129735", ] ) parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[]) diff --git a/.github/scripts/inductor_xpu_test.sh b/.github/scripts/inductor_xpu_test.sh index 77c4a5de1..2f22686fe 100644 --- a/.github/scripts/inductor_xpu_test.sh +++ b/.github/scripts/inductor_xpu_test.sh @@ -61,5 +61,5 @@ fi ulimit -n 1048576 ZE_AFFINITY_MASK=${CARD} \ python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${Real_DT} -d ${DEVICE} -n10 --no-skip --dashboard \ - ${DT_extra} ${Mode_extra} ${Shape_extra} ${partition_flags} ${Model_only_extra} --backend=inductor --timeout=7200 \ + ${DT_extra} ${Mode_extra} ${Shape_extra} ${partition_flags} ${Model_only_extra} --backend=inductor --timeout=10800 \ --output=${LOG_DIR}/${LOG_NAME}.csv 2>&1 | tee ${LOG_DIR}/${LOG_NAME}_card${CARD}.log diff --git a/.github/workflows/inductor_xpu_e2e_ci.yml b/.github/workflows/inductor_xpu_e2e_ci.yml index 79be206ce..c7d408b33 100644 --- a/.github/workflows/inductor_xpu_e2e_ci.yml +++ b/.github/workflows/inductor_xpu_e2e_ci.yml @@ -126,7 +126,7 @@ jobs: cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true) if [ ${failed_case} -ne 0 ];then - grep -E "Failed: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log + grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log exit 1 fi - name: Upload Inductor XPU E2E Data diff --git a/.github/workflows/inductor_xpu_e2e_nightly.yml b/.github/workflows/inductor_xpu_e2e_nightly.yml index 7efa4c032..8307edae7 100644 --- a/.github/workflows/inductor_xpu_e2e_nightly.yml +++ b/.github/workflows/inductor_xpu_e2e_nightly.yml @@ -234,7 +234,7 @@ jobs: cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files failed_case=$(grep "Real failed: models: *[1-9]" ${{ github.workspace }}/upload_files/summary_accuracy.log |wc -l || true) if [ ${failed_case} -ne 0 ];then - grep -E "Failed: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log + grep -E "Real failed: models: [1-9]|Summary for" ${{ github.workspace }}/summary_accuracy.log exit 1 fi - name: Upload Inductor XPU E2E Data @@ -260,19 +260,19 @@ jobs: # Test env build_url="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" repo="${{ github.repository }}" - TORCH_BRANCH_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_BRANCH_ID }} - TORCH_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_COMMIT_ID }} - DRIVER_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.DRIVER_VERSION }} - BUNDLE_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.BUNDLE_VERSION }} - OS_PRETTY_NAME=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.OS_PRETTY_NAME }} - GCC_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.GCC_VERSION }} - TORCHBENCH_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHBENCH_COMMIT_ID }} - TORCHVISION_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHVISION_COMMIT_ID }} - TORCHAUDIO_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHAUDIO_COMMIT_ID }} - # TORCHTEXT_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHTEXT_COMMIT_ID }} - TRANSFORMERS_VERSION=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRANSFORMERS_VERSION }} - TIMM_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TIMM_COMMIT_ID }} - TRITON_COMMIT_ID=${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRITON_COMMIT_ID }} + TORCH_BRANCH_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_BRANCH_ID }}" + TORCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCH_COMMIT_ID }}" + DRIVER_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.DRIVER_VERSION }}" + BUNDLE_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.BUNDLE_VERSION }}" + OS_PRETTY_NAME="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.OS_PRETTY_NAME }}" + GCC_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.GCC_VERSION }}" + TORCHBENCH_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHBENCH_COMMIT_ID }}" + TORCHVISION_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHVISION_COMMIT_ID }}" + TORCHAUDIO_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHAUDIO_COMMIT_ID }}" + # TORCHTEXT_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TORCHTEXT_COMMIT_ID }}" + TRANSFORMERS_VERSION="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRANSFORMERS_VERSION }}" + TIMM_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TIMM_COMMIT_ID }}" + TRITON_COMMIT_ID="${{ needs.Inductor-XPU-E2E-Nightly-Tests.outputs.TRITON_COMMIT_ID }}" # Test status if [ "${{ needs.Inductor-XPU-E2E-Nightly-Tests.result }}" == "success" ];then test_status=Success diff --git a/src/ATen/native/xpu/Activation.cpp b/src/ATen/native/xpu/Activation.cpp index 18fed3eff..ee673fef6 100644 --- a/src/ATen/native/xpu/Activation.cpp +++ b/src/ATen/native/xpu/Activation.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include namespace at { @@ -508,4 +510,126 @@ Tensor& XPUNativeFunctions::leaky_relu_backward_out( return grad_input; } +TensorIterator softplus_meta( + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + Tensor& out) { + return TensorIterator::unary_op(out, self); +} + +Tensor XPUNativeFunctions::softplus( + const Tensor& self, + const Scalar& beta, + const Scalar& threshold) { + Tensor out; + auto iter = softplus_meta(self, beta, threshold, out); + native::xpu::softplus_kernel(iter, beta, threshold); + return iter.output(); +} + +Tensor& XPUNativeFunctions::softplus_out( + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + Tensor& out) { + auto iter = softplus_meta(self, beta, threshold, out); + native::xpu::softplus_kernel(iter, beta, threshold); + return out; +} + +TensorIterator softplus_backward_meta( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + Tensor& grad_input) { + return TensorIterator::borrowing_binary_op(grad_input, grad_output, self); +} + +Tensor XPUNativeFunctions::softplus_backward( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold) { + Tensor grad_input; + auto iter = + softplus_backward_meta(grad_output, self, beta, threshold, grad_input); + native::xpu::softplus_backward_kernel(iter, beta, threshold); + return iter.output(); +} + +Tensor& XPUNativeFunctions::softplus_backward_out( + const Tensor& grad_output, + const Tensor& self, + const Scalar& beta, + const Scalar& threshold, + Tensor& grad_input) { + auto iter = + softplus_backward_meta(grad_output, self, beta, threshold, grad_input); + native::xpu::softplus_backward_kernel(iter, beta, threshold); + return grad_input; +} + +static inline void softshrink_check(const Scalar& lambd) { + double lamb = lambd.to(); + TORCH_CHECK( + lamb >= 0, + "lambda must be greater or equal to 0, but found to be ", + lamb, + "."); +} + +TensorIterator softshrink_meta( + const Tensor& self, + const Scalar& lambd, + Tensor& out) { + softshrink_check(lambd); + return TensorIterator::unary_op(out, self); +} + +Tensor XPUNativeFunctions::softshrink(const Tensor& self, const Scalar& lambd) { + Tensor out; + auto iter = softshrink_meta(self, lambd, out); + native::xpu::softshrink_kernel(iter, lambd); + return iter.output(); +} + +Tensor& XPUNativeFunctions::softshrink_out( + const Tensor& self, + const Scalar& lambd, + Tensor& out) { + auto iter = softshrink_meta(self, lambd, out); + native::xpu::softshrink_kernel(iter, lambd); + return out; +} + +TensorIterator softshrink_backward_meta( + const Tensor& grad_output, + const Tensor& self, + const Scalar& lambd, + Tensor& grad_input) { + return TensorIterator::borrowing_binary_op(grad_input, grad_output, self); +} + +Tensor XPUNativeFunctions::softshrink_backward( + const Tensor& grad_output, + const Tensor& self, + const Scalar& lambd) { + Tensor grad_input; + auto iter = softshrink_backward_meta(grad_output, self, lambd, grad_input); + native::xpu::softshrink_backward_kernel(iter, lambd); + return iter.output(); +} + +Tensor& XPUNativeFunctions::softshrink_backward_out( + const Tensor& grad_output, + const Tensor& self, + const Scalar& lambd, + Tensor& grad_input) { + auto iter = softshrink_backward_meta(grad_output, self, lambd, grad_input); + native::xpu::softshrink_backward_kernel(iter, lambd); + return grad_input; +} + } // namespace at diff --git a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp index 3054992e8..b09d1c8c0 100644 --- a/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp +++ b/src/ATen/native/xpu/AdaptiveAveragePooling2d.cpp @@ -1,32 +1,172 @@ #include +#include #include -#include +#include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +#include + namespace at { +namespace { + +static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) { + c10::SymInt size = 1; + if (sizes.empty()) { + return 1; + } + for (auto d : dim) { + d = at::maybe_wrap_dim(d, static_cast(sizes.size())); + size *= sizes[d]; + } + return size; +} + +Tensor unsqueeze_multiple( + const Tensor& t, + OptionalIntArrayRef opt_dim, + size_t n_dims) { + if (opt_dim.has_value()) { + IntArrayRef dim = opt_dim.value(); + auto dim_size = dim.size(); + // Optimisation for two common cases + if (dim_size == 0) { + return t; + } else if (dim_size == 1) { + return t.unsqueeze(dim[0]); + } + } + auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims); + Tensor res = t; + for (const auto i : c10::irange(n_dims)) { + if (dims_to_unsqueeze[i]) { + res = res.unsqueeze(static_cast(i)); + } + } + return res; +} + +Tensor sum_backward( + const Tensor& grad, + c10::SymIntArrayRef sizes, + OptionalIntArrayRef opt_dims, + bool keepdim) { + if (!keepdim && !sizes.empty()) { + if (opt_dims.has_value() && !opt_dims.value().empty()) { + return unsqueeze_multiple(grad, opt_dims, sizes.size()) + .expand_symint(sizes); + } + } + return grad.expand_symint(sizes); +} + +Tensor mean_backward( + const Tensor& grad, + c10::SymIntArrayRef shape, + OptionalIntArrayRef opt_dim, + c10::SymInt numel, + bool keepdim) { + bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty(); + auto n = + is_all_reduce ? std::move(numel) : _safe_size(shape, opt_dim.value()); + return sum_backward(grad, shape, opt_dim, keepdim) / std::move(n); +} +} // namespace + Tensor XPUNativeFunctions::_adaptive_avg_pool2d_backward( - const Tensor& grad_output_, - const Tensor& input_) { + const Tensor& grad_output, + const Tensor& input) { + TensorArg grad_output_arg{grad_output, "grad_output", 1}, + input_arg{input, "input", 2}; + + native::adaptive_pool_empty_output_check( + grad_output, "adaptive_avg_pool2d_backward"); + + checkAllSameGPU(__func__, {grad_output_arg, input_arg}); + + TORCH_CHECK( + (input.ndimension() == 3 || input.ndimension() == 4), + "non-empty 3D or 4D (batch mode) tensor expected for input"); + + if (grad_output.size(-1) == 1 && grad_output.size(-2) == 1) { + return mean_backward( + grad_output, + input.sym_sizes().vec(), + {-1, -2}, + input.sym_numel(), + true); + } + + globalContext().alertNotDeterministic("_adaptive_avg_pool2d_backward"); + Tensor grad_input; - if (input_.numel() != 0) { - Tensor input, grad_output; - if (input_.ndimension() == 3) { - input = input_.contiguous(); - grad_output = grad_output_.contiguous(); - grad_input = at::empty_like(input); - } else { - auto smf = input_.suggest_memory_format(); - input = input_.contiguous(smf); - grad_output = grad_output_.contiguous(smf); - grad_input = at::empty_like(input_, smf); - } - native::xpu::adaptive_avg_pool2d_backward_out_kernel( + if (input.numel() != 0) { + native::xpu::adaptive_avg_pool2d_backward_kernel( grad_input, grad_output, input); } else { - grad_input = at::zeros_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } + return grad_input; } +Tensor& XPUNativeFunctions::adaptive_avg_pool2d_out( + const Tensor& input, + IntArrayRef output_size, + Tensor& output) { + TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + TORCH_CHECK( + output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); + int64_t ndim = input.dim(); + TORCH_CHECK( + (ndim == 3 || ndim == 4), + "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ", + input.sizes()); + for (const auto i : {-2, -1}) { + TORCH_CHECK( + input.size(i) > 0, + "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + "but input has sizes ", + input.sizes(), + " with dimension ", + i + ndim, + " being " + "empty"); + } + + if (output_size[0] == 1 && output_size[1] == 1) { + if (output.numel() == 0) { + output = input.mean({-1, -2}, /* keepdim = */ true); + } else { + at::mean_out(output, input, {-1, -2}, true, std::nullopt); + } + if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) { + // assert ndim == 4, since ndim = 3 doesn't give channels_last + const auto n = input.sym_size(0); + const auto c = input.sym_size(1); + output.as_strided__symint({n, c, 1, 1}, {c, 1, c, c}); + } + } else { + native::xpu::adaptive_avg_pool2d_kernel(output, input, output_size); + } + return output; +} + +Tensor XPUNativeFunctions::_adaptive_avg_pool2d( + at::Tensor const& input, + IntArrayRef output_size) { + auto output = at::empty({0}, input.options()); + adaptive_avg_pool2d_out(input, output_size, output); + return output; +} + } // namespace at diff --git a/src/ATen/native/xpu/BatchNorm.cpp b/src/ATen/native/xpu/BatchNorm.cpp new file mode 100644 index 000000000..93018263d --- /dev/null +++ b/src/ATen/native/xpu/BatchNorm.cpp @@ -0,0 +1,383 @@ +#include +#include +#include +#include +#include +#include + +namespace at { + +std::tuple XPUNativeFunctions::batch_norm_stats( + const Tensor& input, + double eps) { + return native::xpu::batch_norm_stats_kernel(input, eps); +} + +Tensor XPUNativeFunctions::batch_norm_elemt( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + const Tensor& mean, + const Tensor& invstd, + double eps) { + auto output = at::empty_like(input); + native::xpu::batch_norm_elemt_kernel( + output, input, weight, bias, mean, invstd); + return output; +} + +Tensor& XPUNativeFunctions::batch_norm_elemt_out( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + const Tensor& mean, + const Tensor& invstd, + double eps, + Tensor& out) { + native::xpu::batch_norm_elemt_kernel(out, input, weight, bias, mean, invstd); + return out; +} + +std::tuple XPUNativeFunctions:: + batch_norm_backward_reduce( + const Tensor& grad_out, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const std::optional& weight, + bool input_g, + bool weight_g, + bool bias_g) { + return native::xpu::batch_norm_backward_reduce_kernel( + grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); +} + +Tensor XPUNativeFunctions::batch_norm_backward_elemt( + const Tensor& grad_out, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const std::optional& weight, + const Tensor& sum_dy, + const Tensor& sum_dy_xmu, + const Tensor& count) { + return native::xpu::batch_norm_backward_elemt_kernel( + grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); +} + +std::tuple XPUNativeFunctions::batch_norm_update_stats( + const Tensor& input, + const std::optional& running_mean, + const std::optional& running_var, + double momentum) { + return native::xpu::batch_norm_update_stats_kernel( + input, running_mean, running_var, momentum); +} + +std::tuple XPUNativeFunctions::native_batch_norm( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + const std::optional& running_mean, + const std::optional& running_var, + bool training, + double momentum, + double eps) { + auto output = at::empty_like(input); + int64_t n_input = input.size(1); + auto options = + input.options().dtype(at::toAccumulateType(input.scalar_type(), true)); + auto save_mean = at::empty({n_input}, options); + auto save_invstd = at::empty({n_input}, options); + + native::xpu::batch_norm_kernel( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + output, + save_mean, + save_invstd); + + return std::make_tuple(output, save_mean, save_invstd); +} + +std::tuple XPUNativeFunctions::native_batch_norm_out( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + const std::optional& running_mean, + const std::optional& running_var, + bool training, + double momentum, + double eps, + Tensor& out, + Tensor& save_mean, + Tensor& save_invstd) { + return native::xpu::batch_norm_kernel( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + out, + save_mean, + save_invstd); +} + +std::tuple XPUNativeFunctions:: + native_batch_norm_backward( + const Tensor& grad_out, + const Tensor& input, + const std::optional& weight, + const std::optional& running_mean, + const std::optional& running_var, + const std::optional& save_mean, + const std::optional& save_invstd, + bool train, + double eps, + std::array output_mask) { + return native::xpu::batch_norm_backward_kernel( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask); +} + +std::tuple XPUNativeFunctions::_native_batch_norm_legit( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + Tensor& running_mean, + Tensor& running_var, + bool training, + double momentum, + double eps) { + return XPUNativeFunctions::native_batch_norm( + input, weight, bias, running_mean, running_var, training, momentum, eps); +} + +std::tuple XPUNativeFunctions:: + _native_batch_norm_legit_out( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + Tensor& running_mean, + Tensor& running_var, + bool training, + double momentum, + double eps, + Tensor& out, + Tensor& save_mean, + Tensor& save_invstd) { + return XPUNativeFunctions::native_batch_norm_out( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + out, + save_mean, + save_invstd); +} + +std::tuple XPUNativeFunctions::_native_batch_norm_legit( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + bool training, + double momentum, + double eps) { + return XPUNativeFunctions::native_batch_norm( + input, weight, bias, Tensor(), Tensor(), training, momentum, eps); +} + +std::tuple XPUNativeFunctions:: + _native_batch_norm_legit_out( + const at::Tensor& input, + const std::optional& weight, + const std::optional& bias, + bool training, + double momentum, + double eps, + at::Tensor& out, + at::Tensor& save_mean, + at::Tensor& save_invstd) { + return XPUNativeFunctions::native_batch_norm_out( + input, + weight, + bias, + Tensor(), + Tensor(), + training, + momentum, + eps, + out, + save_mean, + save_invstd); +} + +inline std::tuple batch_norm_with_update( + const Tensor& input, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + double momentum, + double eps) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); }); + Tensor reserve; + + reserve = at::empty({0}, input.options().dtype(kByte)); + + auto output = at::empty_like(input); + int64_t n_input = input.size(1); + auto options = + input.options().dtype(at::toAccumulateType(input.scalar_type(), true)); + auto save_mean = at::empty({n_input}, options); + auto save_invstd = at::empty({n_input}, options); + + native::xpu::batch_norm_kernel( + input, + weight, + bias, + running_mean, + running_var, + /*training*/ true, + momentum, + eps, + output, + save_mean, + save_invstd); + + return std::tuple( + output, save_mean, save_invstd, reserve); +} + +inline std::tuple batch_norm_with_update_out( + const Tensor& input, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + Tensor& running_mean, + Tensor& running_var, + double momentum, + double eps, + Tensor& out, + Tensor& save_mean, + Tensor& save_var, + Tensor& reserve) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); }); + + std::tie(out, save_mean, save_var) = native::xpu::batch_norm_kernel( + input, + weight, + bias, + running_mean, + running_var, + /*update*/ true, + momentum, + eps, + out, + save_mean, + save_var); + + return std::tuple( + out, save_mean, save_var, reserve); +} + +std::tuple XPUNativeFunctions:: + _batch_norm_with_update( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + Tensor& running_mean, + Tensor& running_var, + double momentum, + double eps) { + return batch_norm_with_update( + input, weight, bias, running_mean, running_var, momentum, eps); +} + +std::tuple XPUNativeFunctions:: + _batch_norm_with_update_out( + const Tensor& input, + const std::optional& weight, + const std::optional& bias, + Tensor& running_mean, + Tensor& running_var, + double momentum, + double eps, + Tensor& out, + Tensor& save_mean, + Tensor& save_invstd, + Tensor& reserve) { + return batch_norm_with_update_out( + input, + weight, + bias, + running_mean, + running_var, + momentum, + eps, + out, + save_mean, + save_invstd, + reserve); +} + +std::tuple XPUNativeFunctions::batch_norm_backward( + const Tensor& grad_output, + const Tensor& input, + const Tensor& weight, + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + const std::optional& save_mean_opt, + const std::optional& save_var_opt, + bool update, + double eps, + std::array grad_input_mask, + const Tensor& reserve) { + const Tensor& running_mean = + c10::value_or_else(running_mean_opt, [] { return Tensor(); }); + const Tensor& running_var = + c10::value_or_else(running_var_opt, [] { return Tensor(); }); + const Tensor& save_mean = + c10::value_or_else(save_mean_opt, [] { return Tensor(); }); + const Tensor& save_var = + c10::value_or_else(save_var_opt, [] { return Tensor(); }); + return native::xpu::batch_norm_backward_kernel( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + update, + eps, + grad_input_mask); +} + +} // namespace at diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp index 9680c5073..c50f8305d 100644 --- a/src/ATen/native/xpu/BinaryOps.cpp +++ b/src/ATen/native/xpu/BinaryOps.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -380,6 +381,28 @@ Tensor& XPUNativeFunctions::gcd_out( return out; } +Tensor XPUNativeFunctions::hypot(const Tensor& self, const Tensor& other) { + Tensor out; + auto iter = TensorIterator::borrowing_binary_op(out, self, other); + native::xpu::hypot_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::hypot_(Tensor& self, const Tensor& other) { + auto iter = TensorIterator::borrowing_binary_op(self, self, other); + native::xpu::hypot_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::hypot_out( + const Tensor& self, + const Tensor& other, + Tensor& out) { + auto iter = TensorIterator::borrowing_binary_op(out, self, other); + native::xpu::hypot_kernel(iter); + return out; +} + static inline TensorIterator meta_func_maximum( const Tensor& self, const Tensor& other, diff --git a/src/ATen/native/xpu/Lerp.cpp b/src/ATen/native/xpu/Lerp.cpp new file mode 100644 index 000000000..272417b39 --- /dev/null +++ b/src/ATen/native/xpu/Lerp.cpp @@ -0,0 +1,110 @@ +#include +#include +#include +#include + +#include + +namespace at { + +TensorIterator lerp_tensor_meta( + const Tensor& self, + const Tensor& end, + const Tensor& weight, + Tensor& out) { + TORCH_CHECK( + self.dtype() == end.dtype(), + "expected dtype ", + self.dtype(), + " for `end` but got dtype ", + end.dtype()); + TORCH_CHECK( + self.dtype() == weight.dtype(), + "expected dtype ", + self.dtype(), + " for `weight` but got dtype ", + weight.dtype()); + TensorIterator iter; + iter.build(TensorIteratorConfig() + .add_output(out) + .add_const_input(self) + .add_const_input(end) + .add_const_input(weight)); + return iter; +} + +Tensor XPUNativeFunctions::lerp( + const Tensor& self, + const Tensor& end, + const Tensor& weight) { + Tensor out; + auto iter = lerp_tensor_meta(self, end, weight, out); + native::xpu::lerp_tensor_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::lerp_( + Tensor& self, + const Tensor& end, + const Tensor& weight) { + auto iter = lerp_tensor_meta(self, end, weight, self); + native::xpu::lerp_tensor_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::lerp_out( + const Tensor& self, + const Tensor& end, + const Tensor& weight, + Tensor& out) { + auto iter = lerp_tensor_meta(self, end, weight, out); + native::xpu::lerp_tensor_kernel(iter); + return out; +} + +TensorIterator lerp_scalar_meta( + const Tensor& self, + const Tensor& end, + const Scalar& /*weight*/, + Tensor& out) { + TORCH_CHECK( + self.dtype() == end.dtype(), + "expected dtype ", + self.dtype(), + " for `end` but got dtype ", + end.dtype()); + TensorIterator iter; + iter.build_binary_op(out, self, end); + return iter; +} + +Tensor XPUNativeFunctions::lerp( + const Tensor& self, + const Tensor& end, + const Scalar& weight) { + Tensor out; + auto iter = lerp_scalar_meta(self, end, weight, out); + native::xpu::lerp_scalar_kernel(iter, weight); + return iter.output(); +} + +Tensor& XPUNativeFunctions::lerp_( + Tensor& self, + const Tensor& end, + const Scalar& weight) { + auto iter = lerp_scalar_meta(self, end, weight, self); + native::xpu::lerp_scalar_kernel(iter, weight); + return self; +} + +Tensor& XPUNativeFunctions::lerp_out( + const Tensor& self, + const Tensor& end, + const Scalar& weight, + Tensor& out) { + auto iter = lerp_scalar_meta(self, end, weight, out); + native::xpu::lerp_scalar_kernel(iter, weight); + return out; +} + +} // namespace at diff --git a/src/ATen/native/xpu/Resize.cpp b/src/ATen/native/xpu/Resize.cpp index 405b905bf..ad2a0b586 100644 --- a/src/ATen/native/xpu/Resize.cpp +++ b/src/ATen/native/xpu/Resize.cpp @@ -13,96 +13,13 @@ #include #endif -#include -#include -#include +#include namespace at { namespace native::xpu { extern Tensor& _copy_xpu(Tensor& self, const Tensor& src, bool non_blocking); -void resize_bytes_xpu(StorageImpl* storage, size_t size_bytes) { - TORCH_CHECK( - storage->resizable(), "Trying to resize storage that is not resizable"); - auto allocator = storage->allocator(); - TORCH_CHECK( - allocator != nullptr, "Trying to resize storage without an allocator"); - - c10::Device device = storage->device(); - - if (size_bytes == 0) { - storage->set_data_ptr_noswap(at::DataPtr(nullptr, device)); - storage->set_nbytes(0); - return; - } - - c10::xpu::XPUGuard guard(device.index()); - at::DataPtr data = allocator->allocate(size_bytes); - if (storage->data_ptr()) { - at::globalContext().lazyInitXPU(); - auto q = at::xpu::getCurrentSYCLQueue(); - - q.memcpy( - data.get(), storage->data(), std::min(storage->nbytes(), size_bytes)); - } - - // Destructively overwrite data_ptr - storage->set_data_ptr_noswap(std::move(data)); - storage->set_nbytes(size_bytes); -} - -static inline void maybe_resize_storage_xpu( - TensorImpl* self, - size_t new_size_bytes) { - // It does not make sense to try to resize a storage - // to hold 0 elements, and this can break - // if storage_offset is positive but - // new_size is 0, so just bail in that case - // (same comment is in Resize.h) - if (self->numel() == 0) { - return; - } - - const Storage& storage = self->unsafe_storage(); - TORCH_CHECK(storage, "Tensor: invalid null storage"); - if (new_size_bytes > storage.nbytes()) { - resize_bytes_xpu(storage.unsafeGetStorageImpl(), new_size_bytes); - } -} - -inline TensorImpl* resize_impl_xpu_( - TensorImpl* self, - IntArrayRef size, - at::OptionalIntArrayRef stride, - bool device_guard = true) { - if (self->sizes() == size && (!stride || self->strides() == stride)) { - return self; - } - - // NB: We don't need to hold the device guard when calling from TH - at::xpu::OptionalXPUGuard guard; - if (device_guard) { - guard.set_index(self->storage().device().index()); - } - - const auto itemsize = self->dtype().itemsize(); - const auto storage_offset = self->storage_offset(); - size_t storage_size = 1; - if (stride) { - self->set_sizes_and_strides(size, *stride); - storage_size = at::detail::computeStorageNbytes( - size, *stride, itemsize, storage_offset); - } else { - self->set_sizes_contiguous(size); - storage_size = at::detail::computeStorageNbytesContiguous( - size, itemsize, storage_offset); - } - maybe_resize_storage_xpu(self, storage_size); - - return self; -} - const Tensor& resize_xpu_( const Tensor& self, IntArrayRef size, diff --git a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp index 2a2dd5f1e..62bbd353d 100644 --- a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp +++ b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp @@ -503,16 +503,7 @@ Tensor& XPUNativeFunctions::_index_put_impl_( } } - // Performance consideration: - // Avoid atomic operations when accumulating bf16 and hf16. No efficient - // atomic operation hardware support. We have to do CAS, whose performance - // is worse than deterministic implementation. - bool need_use_deterministic = (accumulate && - (self.scalar_type() == at::kBFloat16 || - self.scalar_type() == at::kHalf)) || - globalContext().deterministicAlgorithms(); - - if (need_use_deterministic) { + if (accumulate || globalContext().deterministicAlgorithms()) { TORCH_CHECK( value_.device() == self.device(), "expected device ", diff --git a/src/ATen/native/xpu/TensorFactories.cpp b/src/ATen/native/xpu/TensorFactories.cpp index 8b4f1b8d0..d7b79902f 100644 --- a/src/ATen/native/xpu/TensorFactories.cpp +++ b/src/ATen/native/xpu/TensorFactories.cpp @@ -1,6 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -11,6 +12,7 @@ #include #endif +#include #include #include @@ -63,6 +65,73 @@ Tensor XPUNativeFunctions::clone( return at::native::clone(self, memory_format); } +Tensor XPUNativeFunctions::_efficientzerotensor( + IntArrayRef size, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory) { + auto device_ = device_or_default(device); + if (!device_.has_index()) { + device_.set_index(c10::xpu::current_device()); + } + auto allocator = at::native::ZeroTensorAllocator(device_); + auto dtype_ = dtype_or_default(dtype); + auto zero_ks = at::DispatchKeySet(c10::DispatchKey::XPU) | + at::DispatchKeySet(c10::DispatchKey::ZeroTensor); + auto out = at::detail::empty_generic( + size, &allocator, zero_ks, dtype_, c10::nullopt); + return out; +} + +static void complex_check_floating(const Tensor& a, const Tensor& b) { + TORCH_CHECK( + (a.scalar_type() == kFloat || a.scalar_type() == kDouble || + a.scalar_type() == kHalf) && + (b.scalar_type() == kFloat || b.scalar_type() == kDouble || + b.scalar_type() == kHalf), + "Expected both inputs to be Half, Float or Double tensors but got ", + a.scalar_type(), + " and ", + b.scalar_type()); +} + +static void complex_check_dtype( + const Tensor& result, + const Tensor& a, + const Tensor& b) { + complex_check_floating(a, b); + TORCH_CHECK( + a.scalar_type() == b.scalar_type(), + "Expected object of scalar type ", + a.scalar_type(), + " but got scalar type ", + b.scalar_type(), + " for second argument"); + TORCH_CHECK( + result.scalar_type() == toComplexType(a.scalar_type()), + "Expected object of scalar type ", + toComplexType(a.scalar_type()), + " but got scalar type ", + result.scalar_type(), + " for argument 'out'"); +} + +Tensor& XPUNativeFunctions::complex_out( + const Tensor& real, + const Tensor& imag, + Tensor& result) { + complex_check_dtype(result, real, imag); + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(real) + .add_const_input(imag) + .check_all_same_dtype(false) + .build(); + native::xpu::complex_kernel(iter); + return result; +} + Tensor& XPUNativeFunctions::randperm_out( int64_t n, c10::optional generator, diff --git a/src/ATen/native/xpu/TriangluarOps.cpp b/src/ATen/native/xpu/TriangluarOps.cpp index 6b5428e6c..affba5665 100644 --- a/src/ATen/native/xpu/TriangluarOps.cpp +++ b/src/ATen/native/xpu/TriangluarOps.cpp @@ -68,4 +68,10 @@ Tensor& XPUNativeFunctions::triu_(Tensor& self, int64_t diagonal) { xpu::check_inplace(self, self.sizes(), self.options()); return triu_out(self, diagonal, self); } + +Tensor XPUNativeFunctions::trace(const Tensor& self) { + TORCH_CHECK(self.dim() == 2, "expected a matrix"); + return self.diagonal().sum(); +} + } // namespace at diff --git a/src/ATen/native/xpu/sycl/UpSample.h b/src/ATen/native/xpu/UpSample.h similarity index 97% rename from src/ATen/native/xpu/sycl/UpSample.h rename to src/ATen/native/xpu/UpSample.h index ffa8e9c56..5ca47c4d4 100644 --- a/src/ATen/native/xpu/sycl/UpSample.h +++ b/src/ATen/native/xpu/UpSample.h @@ -10,9 +10,9 @@ #include #include -namespace at::native { +namespace at::native::xpu { -inline std::array upsample_2d_common_check( +inline C10_UNUSED std::array upsample_2d_common_check( IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( @@ -49,7 +49,6 @@ inline std::array upsample_2d_common_check( return {nbatch, channels, output_height, output_width}; } -namespace xpu { inline size_t idx_cl( const size_t n, @@ -229,6 +228,4 @@ static scalar_t upsample_get_value_bounded( return data[batch][channel][access_y][access_x]; } -} // namespace xpu - -} // namespace at::native +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/UpSampleBicubic2d.cpp b/src/ATen/native/xpu/UpSampleBicubic2d.cpp index d59945135..509d6e449 100644 --- a/src/ATen/native/xpu/UpSampleBicubic2d.cpp +++ b/src/ATen/native/xpu/UpSampleBicubic2d.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include @@ -15,7 +15,7 @@ void upsample_bicubic2d_meta( std::optional scales_h, std::optional scales_w) { auto full_output_size = - native::upsample_2d_common_check(input.sizes(), output_size); + native::xpu::upsample_2d_common_check(input.sizes(), output_size); // Allow for empty batch size but not other dimensions TORCH_CHECK( diff --git a/src/ATen/native/xpu/UpSampleBilinear2d.cpp b/src/ATen/native/xpu/UpSampleBilinear2d.cpp new file mode 100644 index 000000000..f0ace4344 --- /dev/null +++ b/src/ATen/native/xpu/UpSampleBilinear2d.cpp @@ -0,0 +1,160 @@ +#include +#include +#include + +#include +#include +#include + +namespace at { + +void upsample_bilinear2d_meta( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + Tensor& output) { + auto full_output_size = + native::xpu::upsample_2d_common_check(input.sizes(), output_size); + + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + input.numel() != 0 || + c10::multiply_integers( + input.sizes().begin() + 1, input.sizes().end()), + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + + auto memory_format = input.suggest_memory_format(); + if (output.defined()) { + xpu::resize_out( + output, + full_output_size, + {}, + input.options().memory_format(memory_format)); + } else { + output = at::xpu::create_out( + full_output_size, {}, input.options().memory_format(memory_format)); + } +} + +void upsample_bilinear2d_backward_meta( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + Tensor& grad_input) { + auto full_output_size = + native::xpu::upsample_2d_common_check(input_size, output_size); + + TORCH_CHECK( + grad_output.dim() == 4, + "Expected grad_output to be a tensor of dimension 4 but got: dimension ", + grad_output.dim()); + + for (const auto i : c10::irange(4)) { + TORCH_CHECK( + grad_output.size(i) == full_output_size[i], + "Expected grad_output to have the same shape as output;", + " output.size(", + i, + ") = ", + full_output_size[i], + " but got grad_output.size(", + i, + ") = ", + grad_output.size(i)); + } + + auto memory_format = grad_output.suggest_memory_format(); + if (grad_input.defined()) { + xpu::resize_out( + grad_input, + input_size, + {}, + grad_output.options().memory_format(memory_format)); + } else { + grad_input = at::xpu::create_out( + input_size, {}, grad_output.options().memory_format(memory_format)); + } +} + +Tensor& XPUNativeFunctions::upsample_bilinear2d_out( + const Tensor& self, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + Tensor& output) { + upsample_bilinear2d_meta( + self, output_size, align_corners, scales_h, scales_w, output); + native::xpu::upsample_bilinear2d_out_kernel( + output, self, output_size, align_corners, scales_h, scales_w); + return output; +} + +Tensor XPUNativeFunctions::upsample_bilinear2d( + const Tensor& self, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w) { + Tensor output; + upsample_bilinear2d_out( + self, output_size, align_corners, scales_h, scales_w, output); + return output; +} + +Tensor& XPUNativeFunctions::upsample_bilinear2d_backward_out( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + Tensor& grad_input) { + globalContext().alertNotDeterministic("upsample_bilinear2d_backward_xpu"); + + upsample_bilinear2d_backward_meta( + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w, + grad_input); + + native::xpu::upsample_bilinear2d_backward_out_kernel( + grad_input, + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w); + return grad_input; +} + +Tensor XPUNativeFunctions::upsample_bilinear2d_backward( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + Tensor grad_input; + upsample_bilinear2d_backward_out( + grad_output, + output_size, + input_size, + align_corners, + scales_h, + scales_w, + grad_input); + return grad_input; +} + +} // namespace at diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index c952b67ff..7ead65b03 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -157,7 +157,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { */ TORCH_LIBRARY_IMPL(aten, XPU, m) { std::vector fallback_list = { - "_adaptive_avg_pool2d", "_adaptive_avg_pool3d", "_adaptive_avg_pool3d_backward", "adaptive_max_pool2d_backward.grad_input", @@ -187,7 +186,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "cholesky", "cholesky_inverse", "_cholesky_solve_helper", - "complex.out", "conj_physical.out", "copysign.out", "cosh.out", @@ -200,7 +198,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "digamma.out", "dot", "_efficient_attention_forward", - "_efficientzerotensor", "_embedding_bag_dense_backward", "_embedding_bag_per_sample_weights_backward", "equal", @@ -231,7 +228,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "histc", "huber_loss", "huber_loss_backward.out", - "hypot.out", "i0.out", "igammac.out", "igamma.out", @@ -244,8 +240,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "isposinf.out", "kthvalue.values", "lcm.out", - "lerp.Scalar_out", - "lerp.Tensor_out", "lgamma.out", "linalg_cholesky_ex.L", "linalg_cross.out", @@ -301,8 +295,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "nanmedian.dim_values", "nansum", "nan_to_num.out", - "native_batch_norm", - "native_batch_norm_backward", "nextafter.out", "norm.out", "ormqr", @@ -338,10 +330,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "sinh.out", "smooth_l1_loss_backward.grad_input", "smooth_l1_loss.out", - "softplus_backward.grad_input", - "softplus.out", - "softshrink_backward.grad_input", - "softshrink.out", "special_airy_ai.out", "special_bessel_j0.out", "special_bessel_j1.out", @@ -381,7 +369,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "topk.values", "_to_sparse", "_to_sparse_csr", - "trace", "triangular_solve.X", "tril_indices", "triu_indices", @@ -390,8 +377,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "unique_consecutive", "upsample_bicubic2d_backward.grad_input", "_upsample_bilinear2d_aa.out", - "upsample_bilinear2d_backward.grad_input", - "upsample_bilinear2d.out", "upsample_linear1d_backward.grad_input", "upsample_linear1d.out", "upsample_nearest3d.out", diff --git a/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.cpp b/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.cpp new file mode 100644 index 000000000..99b1c5716 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.cpp @@ -0,0 +1,80 @@ +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct SoftplusFunctor { + using opmath_t = at::opmath_type; + scalar_t operator()(scalar_t a) const { + opmath_t aop = static_cast(a); + return (aop * beta_) > threshold_ + ? aop + : (std::log1p(std::exp(aop * beta_))) / beta_; + } + + SoftplusFunctor(opmath_t beta, opmath_t threshold) + : beta_(beta), threshold_(threshold) {} + + private: + opmath_t beta_; + opmath_t threshold_; +}; + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_xpu", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + SoftplusFunctor f(beta, threshold); + gpu_kernel(iter, f); + }); +} + +template +struct SoftplusBackwardFunctor { + using opmath_t = at::opmath_type; + scalar_t operator()(scalar_t a, scalar_t b) const { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + opmath_t z = std::exp(bop * beta_); + return (bop * beta_) > threshold_ ? aop : aop * z / (z + opmath_t(1.)); + } + + SoftplusBackwardFunctor(opmath_t beta, opmath_t threshold) + : beta_(beta), threshold_(threshold) {} + + private: + opmath_t beta_; + opmath_t threshold_; +}; + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_backward_xpu", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + SoftplusBackwardFunctor f(beta, threshold); + gpu_kernel(iter, f); + }); +} +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.h b/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.h new file mode 100644 index 000000000..8a5e5ef2b --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationSoftplusKernels.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_); + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp new file mode 100644 index 000000000..4393576e9 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.cpp @@ -0,0 +1,58 @@ +#include +#include + +#include + +namespace at::native::xpu { + +template +struct SoftshrinkFunctor { + scalar_t operator()(scalar_t a) const { + return a > lambd_ ? a - lambd_ : (a < -lambd_ ? a + lambd_ : scalar_t(0)); + } + + SoftshrinkFunctor(scalar_t lambd) : lambd_(lambd) {} + + private: + scalar_t lambd_; +}; + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softshrink_xpu", + [&]() { + auto lambd = value.to(); + SoftshrinkFunctor f(lambd); + gpu_kernel(iter, f); + }); +} + +template +struct SoftshrinkBackwardFunctor { + scalar_t operator()(scalar_t grad_val, scalar_t self_val) const { + return (self_val >= -lambd_ && self_val <= lambd_) ? scalar_t(0) : grad_val; + } + + SoftshrinkBackwardFunctor(scalar_t lambd) : lambd_(lambd) {} + + private: + scalar_t lambd_; +}; + +void softshrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "shrink_backward_xpu", + [&]() { + auto lambd = value.to(); + SoftshrinkBackwardFunctor f(lambd); + gpu_kernel(iter, f); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h new file mode 100644 index 000000000..481d1e5a1 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value); + +void softshrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.cpp b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.cpp deleted file mode 100644 index 108213065..000000000 --- a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.cpp +++ /dev/null @@ -1,322 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at::native::xpu { - -using namespace at::xpu; - -template -struct AdaptiveAvgPool2dBwdKernelFunctor { - void operator()(sycl::nd_item<1> item) const { - int64_t gi = item.get_global_linear_id(); - - for (int64_t i = gi; i < numel; i += global_range) { - int64_t _iw, _ih, _ic, _ib; - if constexpr (is_channels_last) { - _ic = i % ic; - _iw = i / ic % iw; - _ih = i / ic / iw % ih; - _ib = i / ic / iw / ih; - } else { - _iw = i % iw; - _ih = i / iw % ih; - _ic = i / iw / ih % ic; - _ib = i / iw / ih / ic; - } - - int64_t _oh0 = native::start_index(_ih, ih, oh); - int64_t _oh1 = native::end_index(_ih, ih, oh); - int64_t _ow0 = native::start_index(_iw, iw, ow); - int64_t _ow1 = native::end_index(_iw, iw, ow); - int64_t _ob = _ib; - int64_t _oc = _ic; - - accscalar_t gx = 0; - accscalar_t _ikh, _ikw; - for (int _oh = _oh0; _oh < _oh1; _oh++) { - _ikh = accscalar_t(1.0) / - (accscalar_t)(native::end_index(_oh, oh, ih) - native::start_index(_oh, oh, ih)); - for (int _ow = _ow0; _ow < _ow1; _ow++) { - _ikw = accscalar_t(1.0) / - (accscalar_t)(native::end_index(_ow, ow, iw) - native::start_index(_ow, ow, iw)); - gx += gyacc[_ob][_oc][_oh][_ow] * _ikh * _ikw; - } - } - - const auto store = [](PackedTensorAccessor64 gxacc, - int64_t _ib, - int64_t _ic, - int64_t _ih, - int64_t _iw, - scalar_t res) { gxacc[_ib][_ic][_ih][_iw] = res; }; - store(gxacc, _ib, _ic, _ih, _iw, (scalar_t)gx); - } - } - - AdaptiveAvgPool2dBwdKernelFunctor( - PackedTensorAccessor64 gyacc_, - PackedTensorAccessor64 gxacc_) - : gyacc(gyacc_), gxacc(gxacc_) { - ib = gxacc.size(0); - ic = gxacc.size(1); - ih = gxacc.size(2); - iw = gxacc.size(3); - oh = gyacc.size(2); - ow = gyacc.size(3); - - numel = ib * ic * ih * iw; - int total_item = std::min(numel, syclMaxWorkItemsPerTile()); - local_range = syclMaxWorkItemsPerEU(); - global_range = total_item < local_range - ? local_range - : (total_item / local_range) * local_range; - } - - sycl::range<1> glb_range() { - return sycl::range<1>(global_range); - } - - sycl::range<1> loc_range() { - return sycl::range<1>(local_range); - } - - private: - int ib; - int ic; - int ih; - int iw; - int oh; - int ow; - int64_t numel; - int global_range; - int local_range; - PackedTensorAccessor64 gyacc; - PackedTensorAccessor64 gxacc; -}; - -template -struct AdaptiveAvgPool2dBwdSLMKernelFunctor - : public __SYCL_KER_CONFIG_CONVENTION__ { - void operator()(sycl::nd_item<1> item) const { - int64_t gi = item.get_global_linear_id(); - int64_t li = item.get_local_id(0); - - // for-loop order: oh*ow->ih->iw - // reuse oh*ow(oh0, oh1, ow0, ow1), ih(ikh), iw(ikw) in inner loop. - for (int _ih = li; _ih < ih; _ih += local_range) { - _oh0_cached[_ih] = (int)native::start_index(_ih, ih, oh); - _oh1_cached[_ih] = (int)native::end_index(_ih, ih, oh); - } - for (int _iw = li; _iw < iw; _iw += local_range) { - _ow0_cached[_iw] = (int)native::start_index(_iw, iw, ow); - _ow1_cached[_iw] = (int)native::end_index(_iw, iw, ow); - } - for (int _oh = li; _oh < oh; _oh += local_range) { - _ikh_cached[_oh] = accscalar_t(1.0) / - (accscalar_t)(native::end_index(_oh, oh, ih) - - native::start_index(_oh, oh, ih)); - } - for (int _ow = li; _ow < ow; _ow += local_range) { - _ikw_cached[_ow] = accscalar_t(1.0) / - (accscalar_t)(native::end_index(_ow, ow, iw) - - native::start_index(_ow, ow, iw)); - } - - item.barrier(sycl_local_fence); - - for (int64_t i = gi; i < numel; i += global_range) { - int64_t _iw, _ih, _ic, _ib; - if constexpr (is_channels_last) { - _ic = i % ic; - _iw = i / ic % iw; - _ih = i / ic / iw % ih; - _ib = i / ic / iw / ih; - } else { - _iw = i % iw; - _ih = i / iw % ih; - _ic = i / iw / ih % ic; - _ib = i / iw / ih / ic; - } - - int64_t _oh0, _oh1, _ow0, _ow1; - _oh0 = _oh0_cached[_ih]; - _oh1 = _oh1_cached[_ih]; - _ow0 = _ow0_cached[_iw]; - _ow1 = _ow1_cached[_iw]; - int64_t _ob = _ib; - int64_t _oc = _ic; - - accscalar_t gx = 0; - accscalar_t _ikh, _ikw; - for (int _oh = _oh0; _oh < _oh1; _oh++) { - _ikh = _ikh_cached[_oh]; - for (int _ow = _ow0; _ow < _ow1; _ow++) { - _ikw = _ikw_cached[_ow]; - gx += gyacc[_ob][_oc][_oh][_ow] * _ikh * _ikw; - } - } - - const auto store = [](PackedTensorAccessor64 gxacc, - int64_t _ib, - int64_t _ic, - int64_t _ih, - int64_t _iw, - scalar_t res) { gxacc[_ib][_ic][_ih][_iw] = res; }; - store(gxacc, _ib, _ic, _ih, _iw, (scalar_t)gx); - } - } - - void sycl_ker_config_convention(sycl::handler& cgh) { - _oh0_cached = sycl_local_acc_t(ih, cgh); - _oh1_cached = sycl_local_acc_t(ih, cgh); - _ow0_cached = sycl_local_acc_t(iw, cgh); - _ow1_cached = sycl_local_acc_t(iw, cgh); - _ikh_cached = sycl_local_acc_t(oh, cgh); - _ikw_cached = sycl_local_acc_t(ow, cgh); - } - - AdaptiveAvgPool2dBwdSLMKernelFunctor( - PackedTensorAccessor64 gyacc_, - PackedTensorAccessor64 gxacc_) - : gyacc(gyacc_), gxacc(gxacc_) { - ib = gxacc.size(0); - ic = gxacc.size(1); - ih = gxacc.size(2); - iw = gxacc.size(3); - oh = gyacc.size(2); - ow = gyacc.size(3); - - numel = ib * ic * ih * iw; - int total_item = std::min(numel, syclMaxWorkItemsPerTile()); - - local_range = syclMaxWorkGroupSize(); - global_range = total_item < local_range - ? local_range - : (total_item / local_range) * local_range; - } - - sycl::range<1> glb_range() { - return sycl::range<1>(global_range); - } - - sycl::range<1> loc_range() { - return sycl::range<1>(local_range); - } - - private: - int ib; - int ic; - int ih; - int iw; - int oh; - int ow; - int64_t numel; - int local_range; - int global_range; - PackedTensorAccessor64 gyacc; - PackedTensorAccessor64 gxacc; - sycl_local_acc_t _oh0_cached; - sycl_local_acc_t _oh1_cached; - sycl_local_acc_t _ow0_cached; - sycl_local_acc_t _ow1_cached; - sycl_local_acc_t _ikh_cached; - sycl_local_acc_t _ikw_cached; -}; - -void adaptive_avg_pool2d_backward_out_kernel( - Tensor& gradInput, - const Tensor& gradOutput, - const Tensor& input) { - TensorArg grad_input_arg{gradInput, "gradInput", 1}, - grad_output_arg{gradOutput, "gradOutput", 2}, - input_arg{input, "input", 3}; - adaptive_pool_empty_output_check(gradOutput, "adaptive_avg_pool2d_backward"); - checkAllSameGPU(__func__, {grad_input_arg, grad_output_arg, input_arg}); - - TORCH_CHECK( - (input.ndimension() == 3 || input.ndimension() == 4), - "non-empty 3D or 4D (batch mode) tensor expected for input"); - - auto outputHeight = gradOutput.size(-2); - auto outputWidth = gradOutput.size(-1); - - const auto nInputPlane = input.size(-3); - const auto inputHeight = input.size(-2); - const auto inputWidth = input.size(-1); - - int dH = std::floor((float)2 * inputHeight / outputHeight) - - (inputHeight / outputHeight); - int dW = std::floor((float)2 * inputWidth / outputWidth) - - (inputWidth / outputWidth); - std::vector stride_vec = {dH, dW}; - - int kH = std::ceil((float)2 * inputHeight / outputHeight) - - (inputHeight / outputHeight); - int kW = std::ceil((float)2 * inputWidth / outputWidth) - - (inputWidth / outputWidth); - std::vector kernel_size_vec = {kH, kW}; - - int padH = (dH * (outputHeight - 1) + kH - inputHeight) / 2; - int padW = (dW * (outputWidth - 1) + kW - inputWidth) / 2; - std::vector padding_vec = {padH, padW}; - - bool is_3d = gradOutput.ndimension() == 3; - if (is_3d) { - gradOutput.resize_({1, nInputPlane, outputHeight, outputWidth}); - gradInput.resize_({1, nInputPlane, inputHeight, inputWidth}); - } - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::BFloat16, - at::ScalarType::Half, - gradOutput.scalar_type(), - "adaptive_avg_pool2d_backward_xpu", - [&]() { - using accscalar_t = acc_type; - auto gyacc = gradOutput.packed_accessor64(); - auto gxacc = gradInput.packed_accessor64(); - - int64_t ohw01_shared_size = - ((inputHeight + inputWidth) * 2) * sizeof(int); - int64_t ikhw_shared_size = - (outputHeight + outputWidth) * sizeof(accscalar_t); - bool using_shared = - syclLocalMemSize() >= ohw01_shared_size + ikhw_shared_size; - - auto& q = getCurrentSYCLQueue(); - if (is_smf_channels_last(gradOutput)) { - if (using_shared) { - AdaptiveAvgPool2dBwdSLMKernelFunctor - kfn(gyacc, gxacc); - sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); - } else { - AdaptiveAvgPool2dBwdKernelFunctor kfn( - gyacc, gxacc); - sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); - } - } else { - if (using_shared) { - AdaptiveAvgPool2dBwdSLMKernelFunctor - kfn(gyacc, gxacc); - sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); - } else { - AdaptiveAvgPool2dBwdKernelFunctor kfn( - gyacc, gxacc); - sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); - } - } - }); - - if (is_3d) { - gradOutput.resize_({nInputPlane, outputHeight, outputWidth}); - gradInput.resize_({nInputPlane, inputHeight, inputWidth}); - } -} - -} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.h b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.h deleted file mode 100644 index e56609add..000000000 --- a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernel.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -namespace at::native::xpu { - -void adaptive_avg_pool2d_backward_out_kernel( - Tensor& gradInput, - const Tensor& gradOutput, - const Tensor& input); - -} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp new file mode 100644 index 000000000..aacc66062 --- /dev/null +++ b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp @@ -0,0 +1,513 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at::native::xpu { + +using namespace at::xpu; + +template +struct AdaptiveAvgPool2dBwdKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + int64_t gi = item.get_global_linear_id(); + + for (int64_t i = gi; i < numel_; i += global_range_) { + int64_t _iw, _ih, _ic, _ib; + if constexpr (is_channels_last) { + _ic = i % ic_; + _iw = i / ic_ % iw_; + _ih = i / ic_ / iw_ % ih_; + _ib = i / ic_ / iw_ / ih_; + } else { + _iw = i % iw_; + _ih = i / iw_ % ih_; + _ic = i / iw_ / ih_ % ic_; + _ib = i / iw_ / ih_ / ic_; + } + + int64_t _oh0 = native::start_index(_ih, ih_, oh_); + int64_t _oh1 = native::end_index(_ih, ih_, oh_); + int64_t _ow0 = native::start_index(_iw, iw_, ow_); + int64_t _ow1 = native::end_index(_iw, iw_, ow_); + int64_t _ob = _ib; + int64_t _oc = _ic; + + opmath_t gx = 0; + opmath_t _ikh, _ikw; + for (int _oh = _oh0; _oh < _oh1; _oh++) { + _ikh = opmath_t(1.0) / + (opmath_t)(native::end_index(_oh, oh_, ih_) - native::start_index(_oh, oh_, ih_)); + for (int _ow = _ow0; _ow < _ow1; _ow++) { + _ikw = opmath_t(1.0) / + (opmath_t)(native::end_index(_ow, ow_, iw_) - native::start_index(_ow, ow_, iw_)); + gx += gyacc_[_ob][_oc][_oh][_ow] * _ikh * _ikw; + } + } + + const auto store = [](PackedTensorAccessor64 gxacc, + int64_t _ib, + int64_t _ic, + int64_t _ih, + int64_t _iw, + scalar_t res) { gxacc[_ib][_ic][_ih][_iw] = res; }; + store(gxacc_, _ib, _ic, _ih, _iw, (scalar_t)gx); + } + } + + AdaptiveAvgPool2dBwdKernelFunctor( + PackedTensorAccessor64 gyacc, + PackedTensorAccessor64 gxacc) + : gyacc_(gyacc), gxacc_(gxacc) { + ib_ = gxacc_.size(0); + ic_ = gxacc_.size(1); + ih_ = gxacc_.size(2); + iw_ = gxacc_.size(3); + oh_ = gyacc_.size(2); + ow_ = gyacc_.size(3); + + numel_ = ib_ * ic_ * ih_ * iw_; + int total_item = std::min(numel_, syclMaxWorkItemsPerTile()); + local_range_ = syclMaxWorkItemsPerEU(); + global_range_ = total_item < local_range_ + ? local_range_ + : (total_item / local_range_) * local_range_; + } + + sycl::range<1> glb_range() { + return sycl::range<1>(global_range_); + } + + sycl::range<1> loc_range() { + return sycl::range<1>(local_range_); + } + + private: + int ib_; + int ic_; + int ih_; + int iw_; + int oh_; + int ow_; + int64_t numel_; + int global_range_; + int local_range_; + PackedTensorAccessor64 gyacc_; + PackedTensorAccessor64 gxacc_; +}; + +template +struct AdaptiveAvgPool2dBwdSLMKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<1> item) const { + int64_t gi = item.get_global_linear_id(); + int64_t li = item.get_local_id(0); + + // for-loop order: oh*ow->ih->iw + // reuse oh*ow(oh0, oh1, ow0, ow1), ih(ikh), iw(ikw) in inner loop. + for (int _ih = li; _ih < ih_; _ih += local_range_) { + _oh0_cached_[_ih] = (int)native::start_index(_ih, ih_, oh_); + _oh1_cached_[_ih] = (int)native::end_index(_ih, ih_, oh_); + } + for (int _iw = li; _iw < iw_; _iw += local_range_) { + _ow0_cached_[_iw] = (int)native::start_index(_iw, iw_, ow_); + _ow1_cached_[_iw] = (int)native::end_index(_iw, iw_, ow_); + } + for (int _oh = li; _oh < oh_; _oh += local_range_) { + _ikh_cached_[_oh] = opmath_t(1.0) / + (opmath_t)(native::end_index(_oh, oh_, ih_) - + native::start_index(_oh, oh_, ih_)); + } + for (int _ow = li; _ow < ow_; _ow += local_range_) { + _ikw_cached_[_ow] = opmath_t(1.0) / + (opmath_t)(native::end_index(_ow, ow_, iw_) - + native::start_index(_ow, ow_, iw_)); + } + + item.barrier(sycl_local_fence); + + for (int64_t i = gi; i < numel_; i += global_range_) { + int64_t _iw, _ih, _ic, _ib; + if constexpr (is_channels_last) { + _ic = i % ic_; + _iw = i / ic_ % iw_; + _ih = i / ic_ / iw_ % ih_; + _ib = i / ic_ / iw_ / ih_; + } else { + _iw = i % iw_; + _ih = i / iw_ % ih_; + _ic = i / iw_ / ih_ % ic_; + _ib = i / iw_ / ih_ / ic_; + } + + int64_t _oh0, _oh1, _ow0, _ow1; + _oh0 = _oh0_cached_[_ih]; + _oh1 = _oh1_cached_[_ih]; + _ow0 = _ow0_cached_[_iw]; + _ow1 = _ow1_cached_[_iw]; + int64_t _ob = _ib; + int64_t _oc = _ic; + + opmath_t gx = 0; + opmath_t _ikh, _ikw; + for (int _oh = _oh0; _oh < _oh1; _oh++) { + _ikh = _ikh_cached_[_oh]; + for (int _ow = _ow0; _ow < _ow1; _ow++) { + _ikw = _ikw_cached_[_ow]; + gx += gyacc_[_ob][_oc][_oh][_ow] * _ikh * _ikw; + } + } + + const auto store = [](PackedTensorAccessor64 gxacc, + int64_t _ib, + int64_t _ic, + int64_t _ih, + int64_t _iw, + scalar_t res) { gxacc[_ib][_ic][_ih][_iw] = res; }; + store(gxacc_, _ib, _ic, _ih, _iw, (scalar_t)gx); + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + _oh0_cached_ = sycl_local_acc_t(ih_, cgh); + _oh1_cached_ = sycl_local_acc_t(ih_, cgh); + _ow0_cached_ = sycl_local_acc_t(iw_, cgh); + _ow1_cached_ = sycl_local_acc_t(iw_, cgh); + _ikh_cached_ = sycl_local_acc_t(oh_, cgh); + _ikw_cached_ = sycl_local_acc_t(ow_, cgh); + } + + AdaptiveAvgPool2dBwdSLMKernelFunctor( + PackedTensorAccessor64 gyacc, + PackedTensorAccessor64 gxacc) + : gyacc_(gyacc), gxacc_(gxacc) { + ib_ = gxacc_.size(0); + ic_ = gxacc_.size(1); + ih_ = gxacc_.size(2); + iw_ = gxacc_.size(3); + oh_ = gyacc_.size(2); + ow_ = gyacc_.size(3); + + numel_ = ib_ * ic_ * ih_ * iw_; + int total_item = std::min(numel_, syclMaxWorkItemsPerTile()); + + local_range_ = syclMaxWorkGroupSize(); + global_range_ = total_item < local_range_ + ? local_range_ + : (total_item / local_range_) * local_range_; + } + + sycl::range<1> glb_range() { + return sycl::range<1>(global_range_); + } + + sycl::range<1> loc_range() { + return sycl::range<1>(local_range_); + } + + private: + int ib_; + int ic_; + int ih_; + int iw_; + int oh_; + int ow_; + int64_t numel_; + int local_range_; + int global_range_; + PackedTensorAccessor64 gyacc_; + PackedTensorAccessor64 gxacc_; + sycl_local_acc_t _oh0_cached_; + sycl_local_acc_t _oh1_cached_; + sycl_local_acc_t _ow0_cached_; + sycl_local_acc_t _ow1_cached_; + sycl_local_acc_t _ikh_cached_; + sycl_local_acc_t _ikw_cached_; +}; + +void adaptive_avg_pool2d_backward_kernel( + Tensor& grad_input, + const Tensor& grad_output_, + const Tensor& input_) { + Tensor input, grad_output; + if (input_.ndimension() == 3) { + input = input_.contiguous(); + grad_output = grad_output_.contiguous(); + grad_input = at::empty_like(input); + } else { + auto smf = input_.suggest_memory_format(); + input = input_.contiguous(smf); + grad_output = grad_output_.contiguous(smf); + grad_input = at::empty_like(input_, smf); + } + + auto outputHeight = grad_output.size(-2); + auto outputWidth = grad_output.size(-1); + + const auto nInputPlane = input.size(-3); + const auto inputHeight = input.size(-2); + const auto inputWidth = input.size(-1); + + int dH = std::floor((float)2 * inputHeight / outputHeight) - + (inputHeight / outputHeight); + int dW = std::floor((float)2 * inputWidth / outputWidth) - + (inputWidth / outputWidth); + std::vector stride_vec = {dH, dW}; + + int kH = std::ceil((float)2 * inputHeight / outputHeight) - + (inputHeight / outputHeight); + int kW = std::ceil((float)2 * inputWidth / outputWidth) - + (inputWidth / outputWidth); + std::vector kernel_size_vec = {kH, kW}; + + int padH = (dH * (outputHeight - 1) + kH - inputHeight) / 2; + int padW = (dW * (outputWidth - 1) + kW - inputWidth) / 2; + std::vector padding_vec = {padH, padW}; + + bool is_3d = grad_output.ndimension() == 3; + if (is_3d) { + grad_output.resize_({1, nInputPlane, outputHeight, outputWidth}); + grad_input.resize_({1, nInputPlane, inputHeight, inputWidth}); + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + grad_output.scalar_type(), + "adaptive_avg_pool2d_backward_xpu", + [&]() { + using opmath_t = at::opmath_type; + auto gyacc = grad_output.packed_accessor64(); + auto gxacc = grad_input.packed_accessor64(); + + int64_t ohw01_shared_size = + ((inputHeight + inputWidth) * 2) * sizeof(int); + int64_t ikhw_shared_size = + (outputHeight + outputWidth) * sizeof(opmath_t); + bool using_shared = + syclLocalMemSize() >= ohw01_shared_size + ikhw_shared_size; + + auto& q = getCurrentSYCLQueue(); + if (is_smf_channels_last(grad_output)) { + if (using_shared) { + AdaptiveAvgPool2dBwdSLMKernelFunctor kfn( + gyacc, gxacc); + sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); + } else { + AdaptiveAvgPool2dBwdKernelFunctor kfn( + gyacc, gxacc); + sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); + } + } else { + if (using_shared) { + AdaptiveAvgPool2dBwdSLMKernelFunctor kfn( + gyacc, gxacc); + sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); + } else { + AdaptiveAvgPool2dBwdKernelFunctor kfn( + gyacc, gxacc); + sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn); + } + } + }); + + if (is_3d) { + grad_output.resize_({nInputPlane, outputHeight, outputWidth}); + grad_input.resize_({nInputPlane, inputHeight, inputWidth}); + } +} + +template +struct AdaptiveAvgPool2dKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + int64_t gi = item.get_global_linear_id(); + for (int64_t i = gi; i < numel_; i += global_range_) { + int64_t _ow, _oh, _oc, _ob; + if constexpr (is_channels_last) { + _oc = i % oc_; + _ow = i / oc_ % ow_; + _oh = i / oc_ / ow_ % oh_; + _ob = i / oc_ / ow_ / oh_; + } else { + _ow = i % ow_; + _oh = i / ow_ % oh_; + _oc = i / ow_ / oh_ % oc_; + _ob = i / ow_ / oh_ / oc_; + } + + int64_t _ih0 = native::start_index(_oh, oh_, ih_); + int64_t _ih1 = native::end_index(_oh, oh_, ih_); + int64_t _iw0 = native::start_index(_ow, ow_, iw_); + int64_t _iw1 = native::end_index(_ow, ow_, iw_); + int64_t kh = _ih1 - _ih0; + int64_t kw = _iw1 - _iw0; + int64_t _ib = _ob; + int64_t _ic = _oc; + + opmath_t sum = static_cast(0); + for (int _ih = _ih0; _ih < _ih1; _ih++) { + for (int _iw = _iw0; _iw < _iw1; _iw++) { + sum += opmath_t(input_[_ib][_ic][_ih][_iw]); + } + } + opmath_t avg = sum / kh / kw; + + const auto store = [](PackedTensorAccessor64 oacc, + int64_t _ob, + int64_t _oc, + int64_t _oh, + int64_t _ow, + scalar_t res) { oacc[_ob][_oc][_oh][_ow] = res; }; + store(output_, _ob, _oc, _oh, _ow, avg); + } + } + AdaptiveAvgPool2dKernelFunctor( + int ih, + int iw, + int ob, + int oc, + int oh, + int ow, + int64_t numel, + int global_range, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output) + : ih_(ih), + iw_(iw), + ob_(ob), + oc_(oc), + oh_(oh), + ow_(ow), + numel_(numel), + global_range_(global_range), + input_(input), + output_(output) {} + + private: + int ih_; + int iw_; + int ob_; + int oc_; + int oh_; + int ow_; + int64_t numel_; + int global_range_; + PackedTensorAccessor64 input_; + PackedTensorAccessor64 output_; +}; + +template +void launch_adaptive_avg_pool2d_kernel( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output) { + int ih = input.size(2); + int iw = input.size(3); + int ob = output.size(0); + int oc = output.size(1); + int oh = output.size(2); + int ow = output.size(3); + + int64_t numel = ob * oc * oh * ow; + int total_item = std::min(numel, syclMaxWorkItemsPerTile()); + int local_range = syclMaxWorkItemsPerEU(); + int global_range = total_item < local_range + ? local_range + : ((total_item + local_range - 1) / local_range) * local_range; + auto caller = + AdaptiveAvgPool2dKernelFunctor( + ih, iw, ob, oc, oh, ow, numel, global_range, input, output); + sycl_kernel_submit( + sycl::range<1>(global_range), + sycl::range<1>(local_range), + getCurrentSYCLQueue(), + caller); +} + +void adaptive_avg_pool2d_kernel( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) { + auto outputWidth = output_size[1]; + auto outputHeight = output_size[0]; + + if (!input.is_quantized() && outputWidth == 1 && outputHeight == 1) { + // in this case, adaptive pooling is just computing mean over hw + // dimensions, which can be done more efficiently + + output = input.mean({-1, -2}, /* keepdim = */ true); + if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) { + // assert ndim == 4, since ndim = 3 doesn't give channels_last + const int n = input.size(0); + const int c = input.size(1); + output.as_strided_({n, c, 1, 1}, {c, 1, c, c}); + } + return; + } + + /* sizes */ + const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1; + const auto nInputPlane = input.size(-3); + const auto inputHeight = input.size(-2); + const auto inputWidth = input.size(-1); + Tensor input_; + if (input.ndimension() == 3) { + input_ = input.contiguous(); + output.resize_({nInputPlane, outputHeight, outputWidth}); + } else { + auto smf = input.suggest_memory_format(); + input_ = input.contiguous(smf); + output.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, smf); + } + if (output.numel() == 0) { + return; + } + int dH = std::floor((float)2 * inputHeight / outputHeight) - + (inputHeight / outputHeight); + int dW = std::floor((float)2 * inputWidth / outputWidth) - + (inputWidth / outputWidth); + std::vector stride_vec = {dH, dW}; + + int kH = std::ceil((float)2 * inputHeight / outputHeight) - + (inputHeight / outputHeight); + int kW = std::ceil((float)2 * inputWidth / outputWidth) - + (inputWidth / outputWidth); + std::vector kernel_size_vec = {kH, kW}; + + int padH = (dH * (outputHeight - 1) + kH - inputHeight) / 2; + int padW = (dW * (outputWidth - 1) + kW - inputWidth) / 2; + std::vector padding_vec = {padH, padW}; + + bool is_3d = input_.ndimension() == 3; + if (is_3d) { + input_.resize_({1, nInputPlane, inputHeight, inputWidth}); + output.resize_({1, nInputPlane, outputHeight, outputWidth}); + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, + at::ScalarType::Half, + input_.scalar_type(), + "adaptive_avg_pool2d_xpu", + [&]() { + using opmath_t = at::opmath_type; + auto iacc = input_.packed_accessor64(); + auto oacc = output.packed_accessor64(); + if (is_smf_channels_last(output)) { + launch_adaptive_avg_pool2d_kernel( + iacc, oacc); + } else { + launch_adaptive_avg_pool2d_kernel( + iacc, oacc); + } + }); + + if (is_3d) { + input_.resize_({nInputPlane, inputHeight, inputWidth}); + output.resize_({nInputPlane, outputHeight, outputWidth}); + } +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.h b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.h new file mode 100644 index 000000000..9b6d9a046 --- /dev/null +++ b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void adaptive_avg_pool2d_backward_kernel( + Tensor& gradInput, + const Tensor& gradOutput, + const Tensor& input); + +void adaptive_avg_pool2d_kernel( + Tensor& output, + const Tensor& input, + IntArrayRef output_size); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp new file mode 100644 index 000000000..5c64816fb --- /dev/null +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -0,0 +1,3927 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +#define SIMD32 32 +#define SIMD16 16 + +// ========================== batch_norm utils ========================== + +ScalarType first_type() { + return ScalarType::Undefined; +} + +template +ScalarType first_type(const Tensor& arg, const Args&... parameters) { + return arg.defined() ? arg.scalar_type() : first_type(parameters...); +} + +// A transform is mixed type if the parameters are higher precision than the +// input +template +bool is_mixed_type(const Tensor& input, const Args&... parameters) { + const auto parameter_type = first_type(parameters...); + return ( + (parameter_type != ScalarType::Undefined) && + (parameter_type != input.scalar_type())); +} + +inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) { + return ( + self.is_contiguous(at::MemoryFormat::ChannelsLast) || + self.is_contiguous(at::MemoryFormat::ChannelsLast3d) || + (self.is_contiguous() && self.strides()[1] == 1)); +} + +enum class Impl { + Contiguous, + ChannelsLast, + General, +}; + +inline Impl batch_norm_choose_impl(const Tensor& self) { + if (!canUse32BitIndexMath(self)) { + return Impl::General; + } + + if (self.is_contiguous()) { + return self.strides()[1] == 1 ? Impl::ChannelsLast : Impl::Contiguous; + } + + if (self.is_contiguous(at::MemoryFormat::ChannelsLast)) { + return Impl::ChannelsLast; + } + + return Impl::General; +} + +inline Impl batch_norm_choose_impl(const Tensor& in1, const Tensor& in2) { + auto imp1 = batch_norm_choose_impl(in1); + if (imp1 == Impl::General) { + return imp1; + } + auto imp2 = batch_norm_choose_impl(in2); + return imp1 == imp2 ? imp1 : Impl::General; +} + +template < + typename scalar_t, + int64_t dim, + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +static GenericPackedTensorAccessor +get_packed_accessor(const Tensor& t, c10::string_view var_name) { + constexpr auto expect_type = c10::CppTypeToScalarType< + typename std::remove_const::type>::value; + const auto actual_type = t.scalar_type(); + TORCH_CHECK( + actual_type == expect_type, + "Expected ", + var_name, + " to have type ", + expect_type, + " but got ", + actual_type); + return t.generic_packed_accessor(); +} + +template < + typename scalar_t, + int64_t dim, + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +static GenericPackedTensorAccessor +packed_accessor_or_dummy(const Tensor& t, c10::string_view var_name) { + if (!t.defined()) { + const std::array zeros{{0}}; + return GenericPackedTensorAccessor( + nullptr, zeros.data(), zeros.data()); + } + return get_packed_accessor(t, var_name); +} + +struct InvStd { + template + inline T operator()(T var, double epsilon) const { + T invstd = 0.0f; + if (var != static_cast(0.0f) || epsilon != static_cast(0.0f)) { + invstd = static_cast(1.0f) / std::sqrt(var + static_cast(epsilon)); + } + return invstd; + } +}; + +struct Var { + template + inline T operator()(T var, double epsilon) const { + return var; + } +}; + +int get_max_group_size(int simd = SIMD32) { + // The max work group size required by batch_norm needs to ensure that the two + // subgroup reduces can obtain correct results. + int max_size = syclMaxWorkGroupSize(); + int shfl2_restricted_size = simd * simd; + return max_size > shfl2_restricted_size ? shfl2_restricted_size : max_size; +} + +int get_num_threads(int nelem, int restricted_simd = SIMD32) { + int max_size = get_max_group_size(restricted_simd); + int thread_sizes[5] = {32, 64, 128, 256, max_size}; + for (int i = 0; i < 5; ++i) { + if (nelem <= thread_sizes[i]) { + return thread_sizes[i]; + } + } + return max_size; +} + +int get_prefer_simd(int numPlane, int nHw) { + // decide SIMD: SIMD32 or SIMD16 + + auto dev_id = at::xpu::getDeviceIndexOfCurrentQueue(); + + auto* dev_prop = at::xpu::getDeviceProperties(dev_id); + auto sub_group_size = dev_prop->sub_group_sizes; + int simd = sub_group_size[1]; + if (simd <= SIMD16) + return simd; + + // if max supported simd >16 + if (nHw <= SIMD16) + return SIMD16; + if (simd >= SIMD32 && nHw <= SIMD32) + return SIMD32; + + int64_t target_tile_size = syclMaxWorkItemsPerTile(dev_id); + // for work group barrier perf + int64_t wg_size = syclMaxWorkItemsPerEU(dev_id); + if (simd == SIMD32) { + // when setting wg_size 256 can achieve high occupancy, use SIMD16 + if (wg_size * numPlane >= target_tile_size) + return SIMD16; + // for latency case + if (nHw <= 1024 && numPlane > 128 && SIMD16 * SIMD16 >= wg_size) { + return SIMD16; + } + } + return simd; +} + +template +struct Float2 { + accscalar_t v1, v2; + Float2() {} + + Float2(scalar_t v1, scalar_t v2) + : v1(static_cast(v1)), v2(static_cast(v2)) {} + Float2(int v) + : v1(static_cast(v)), v2(static_cast(v)) {} + Float2& operator+=(const Float2& a) { + v1 += a.v1; + v2 += a.v2; + return *this; + } + + friend Float2 operator+(Float2 a, const Float2& b) { + a += b; + return a; + } +}; + +template +struct GradOp { + GradOp(accscalar_t m, const PTA& i, const PTA& g) + : mean(m), input(i), grad_output(g) {} + Float2 operator()(int batch, int plane, int n) const { + accscalar_t g = grad_output[batch][plane][n]; + accscalar_t c = static_cast(input[batch][plane][n]) - mean; + return Float2(g, g * c); + } + const accscalar_t mean; + const PTA& input; + const PTA& grad_output; +}; + +template < + int SIMD, + typename accscalar_t, + typename reduce_op, + typename item_t, + typename local_shared_t> +static inline void group_reduce( + item_t item, + int sub_group_num, + accscalar_t& val, + accscalar_t init, + const local_shared_t& local_data, + reduce_op bin_op) { + auto sg = item.get_sub_group(); + uint32_t lane_id = sg.get_local_linear_id(); + uint32_t sg_id = sg.get_group_linear_id(); + + // dynamic get SIMD width result in big performance drop + // uint32_t SIMD = sg.get_local_range()[0]; +#pragma unroll + for (int i = 1; i < SIMD; i <<= 1) { + val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + } + if (sub_group_num == 1) { + if (lane_id == 0) { + local_data[0] = val; + } + item.barrier(sycl_local_fence); + val = local_data[0]; + + return; + } + + // reduce internal each subgroup, each subgroup will generate one result + // there are WGroupSize/subGroupSize elements after this step + if (lane_id == 0) { + local_data[sg_id] = val; + } + item.barrier(sycl_local_fence); + + // use one subgroup to reduce WGroupSize/subGroupSize elements + // into the final result + if (sg_id == 0) { + val = init; + if (lane_id < sub_group_num) { + val = accscalar_t(local_data[lane_id]); + } + for (int i = lane_id + SIMD; i < sub_group_num; i += SIMD) { + val = bin_op(val, static_cast(local_data[i])); + } +#pragma unroll + for (int i = 1; i < SIMD; i <<= 1) { + val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + if (i >= ((sub_group_num + 1) >> 1)) + break; + } + + // the 0th WI (the 0th WI in the 0th sub_group) generate the final + // result + if (lane_id == 0) { + local_data[0] = val; + } + } + + item.barrier(sycl_local_fence); + val = local_data[0]; +} + +template < + int SIMD, + typename scalar_t, + typename item_t, + typename Op, + typename PTA, + typename local_shared_t> +scalar_t plane_reduce( + item_t item, + Op grad_op, + PTA tensor, + int plane, + int sub_group_num, + const local_shared_t& shared) { + // first the reductions each thread does separately + scalar_t sum_value = 0; + for (int batch = item.get_local_id(0); batch < tensor.size(0); + batch += item.get_local_range(0)) { + for (int x = item.get_local_id(1); x < tensor.size(2); + x += item.get_local_range(1)) { + auto res = grad_op(batch, plane, x); + sum_value += res; + } + } + group_reduce( + item, + sub_group_num, + sum_value, + scalar_t(0), + shared, + [](scalar_t a, scalar_t b) { return a + b; }); + if (item.get_local_linear_id() == 0) { + shared[0] = sum_value; + } + item.barrier(sycl_local_fence); + // Everyone picks it up, should be broadcast into the whole grad_input + return shared[0]; +} + +inline int div_up(int a, int b) { + return (a + b - 1) / b; +} + +constexpr int ELEMENTS_PER_ITER = + 4; // enables concurrency within each thread to hide latency +constexpr int ELEMENTS_PER_WORK_ITEM = 16; + +std::tuple, sycl::range<2>> get_adaptive_launch_config( + const int reduction, + const int stride, + const bool coop_flag = false, + const int loops_per_item = 1) { + int max_wg_size = syclMaxWorkItemsPerEU(); + int group_x = std::min(last_pow2(stride), 32); + int group_y = std::min( + last_pow2(div_up(reduction, loops_per_item)), max_wg_size / group_x); + if (group_x * group_y != max_wg_size) { + group_x = std::min(last_pow2(stride), max_wg_size / group_y); + } + + int nwg_x = div_up(stride, group_x); + int nwg_y = std::min( + div_up(reduction, group_y * loops_per_item), + int(syclMaxWorkItemsPerTile()) / (nwg_x * group_x) / (group_y)); + nwg_y = std::max(nwg_y, 1); + + if (coop_flag) { + // it's not worth having a grid reduction if the reduction dimension is not + // big enough + nwg_y = nwg_y < 8 ? 1 : nwg_y; + } + + sycl::range<2> local_range(group_y, group_x); + sycl::range<2> global_range(nwg_y * group_y, nwg_x * group_x); + + return std::make_tuple(global_range, local_range); +} + +template +inline void welford_merge_element( + C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / std::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +// ========================== batch_norm_stats ========================== + +template < + int SIMD, + typename VarTransform, + typename input_scalar_t, + typename stat_scalar_t, + typename stat_accscalar_t, + typename index_t> +struct BatchNormCollectStatisticsKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<2> item) const { + int plane = item.get_group(1); + int tid = item.get_local_linear_id(); + + auto sg = item.get_sub_group(); + auto sg_lid = sg.get_local_linear_id(); + auto sg_id = sg.get_group_linear_id(); + + // Compute the mean and variance across (batch, x/y/z) + // this uses the Welford (in the for loop)/parallel algorithm (to sum + // across the group) + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + // and the parallel algorithm on the same page. + // We use two shuffles to reduce across the entire group. + + // first the reductions each thread does separately + stat_accscalar_t avg = 0; + stat_accscalar_t var_n = 0; + int n = 0; + for (int batch = item.get_local_id(0); batch < input_.size(0); + batch += item.get_local_range(0)) { + for (int x = item.get_local_id(1); x < input_.size(2); + x += item.get_local_range(1)) { + stat_accscalar_t v = input_[batch][plane][x]; + stat_accscalar_t d1 = v - avg; + n++; + avg += d1 / n; + var_n += d1 * (v - avg); + } + } + + // first subgroupSum to get one value per thread to + // one value per subgroup +#pragma unroll + for (int i = 1; i < SIMD; i <<= 1) { + stat_accscalar_t o_avg = sg.shuffle_xor(avg, i); + int o_n = sg.shuffle_xor(n, i); + stat_accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n); + var_n += sg.shuffle_xor(var_n, i) + + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // this writes each subgroups item into shared memory + if (sg_lid == 0) { + shared_n_[sg_id] = n; + shared_avg_var_[sg_id * 2] = avg; + shared_avg_var_[sg_id * 2 + 1] = var_n; + } + item.barrier(sycl_local_fence); + // now have a second subgroupSum to reduce the intermediate values + // from shared memory to a single number. The very first + // thread writes it to shared memory. + int num_sg = item.get_local_range(1) * item.get_local_range(0) / SIMD; + if (tid < num_sg) { + n = shared_n_[tid]; + avg = shared_avg_var_[2 * tid]; + var_n = shared_avg_var_[2 * tid + 1]; + } else { + n = 0; + avg = stat_accscalar_t(0); + var_n = stat_accscalar_t(0); + } +#pragma unroll + for (int i = 1; i < SIMD; i <<= 1) { + stat_accscalar_t o_avg = sg.shuffle_xor(avg, i); + int o_n = sg.shuffle_xor(n, i); + stat_accscalar_t factor = 1.0f / fmaxf(1.0f, n + o_n); + var_n += sg.shuffle_xor(var_n, i) + + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + avg = (n * avg + o_n * o_avg) * factor; + n += o_n; + } + + // Save the mean, variance, and moving averages + auto save_mean = save_mean_; + auto save_transformed_var = save_transformed_var_; + if (tid == 0) { + if (save_mean_.data() != NULL) { + save_mean[plane] = avg; + } + if (save_transformed_var_.data() != NULL) { + save_transformed_var[plane] = + VarTransform{}(var_n / (input_.size(0) * input_.size(2)), epsilon_); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_n_ = sycl_local_acc_t(sycl::range<1>{(size_t)SIMD}, cgh); + shared_avg_var_ = sycl_local_acc_t( + sycl::range<1>{(size_t)SIMD * 2 * 2}, cgh); + } + + BatchNormCollectStatisticsKernelFunctor( + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + RestrictPtrTraits, + index_t> input, + const stat_accscalar_t epsilon, + const stat_accscalar_t momentum, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + RestrictPtrTraits, + index_t> save_mean, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + RestrictPtrTraits, + index_t> save_transformed_var) + : input_(input), + epsilon_(epsilon), + momentum_(momentum), + save_mean_(save_mean), + save_transformed_var_(save_transformed_var) {} + + private: + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + RestrictPtrTraits, + index_t> + input_; + const stat_accscalar_t epsilon_; + const stat_accscalar_t momentum_; + GenericPackedTensorAccessor + save_mean_; + GenericPackedTensorAccessor + save_transformed_var_; + sycl_local_acc_t shared_n_; + sycl_local_acc_t shared_avg_var_; +}; + +template +void batch_norm_stats_template( + const Tensor& out_mean, + const Tensor& out_invstd, + const Tensor& input_, + double epsilon) { + using accscalar_t = at::acc_type; + int64_t n_input = input_.size(1); + Tensor dummy_mean_; + Tensor dummy_var_; + auto input_reshaped = input_.reshape( + {input_.size(0), + input_.size(1), + -1}); // internally we merge the feature dimensions + + at::native::resize_output(out_mean, {n_input}); + at::native::resize_output(out_invstd, {n_input}); + + auto input = + get_packed_accessor( + input_reshaped, "input"); + + TORCH_INTERNAL_ASSERT( + out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT( + out_mean.dim() == 1 && out_mean.is_contiguous() && out_mean.sizes()[0]); + + auto mean = + packed_accessor_or_dummy( + out_mean, "out_mean"); + auto invstd = + packed_accessor_or_dummy( + out_invstd, "out_invstd"); + + auto& queue = getCurrentSYCLQueue(); + int simd = get_prefer_simd(input.size(1), input.size(0) * input.size(2)); + int max_group_size = get_max_group_size(simd); + int tf = get_num_threads(input.size(2), simd); + int64_t work_group_size_x = tf; + int64_t work_group_size_y = std::max(1, max_group_size / tf); + int64_t global_size_x = input.size(1) * work_group_size_x; + int64_t global_size_y = 1 * work_group_size_y; + + if (simd == SIMD32) { + auto caller = BatchNormCollectStatisticsKernelFunctor< + SIMD32, + VarTransform, + scalar_t, + scalar_t, + accscalar_t, + index_t>(input, epsilon, 0.0, mean, invstd); + sycl_kernel_submit( + sycl::range<2>(global_size_y, global_size_x), + sycl::range<2>(work_group_size_y, work_group_size_x), + queue, + caller); + } else { + auto caller = BatchNormCollectStatisticsKernelFunctor< + SIMD16, + VarTransform, + scalar_t, + scalar_t, + accscalar_t, + index_t>(input, epsilon, 0.0, mean, invstd); + sycl_kernel_submit( + sycl::range<2>(global_size_y, global_size_x), + sycl::range<2>(work_group_size_y, work_group_size_x), + queue, + caller); + } +} + +template < + typename scalar_t, + typename accscalar_t, + typename vec_t, + typename vec_y, + int vec_size, + bool two_pass_reduce> +struct BatchNormReduceSumChannelsLastKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + // int plane = item.get_group(0); + // int tid = item.get_local_linear_id(); + auto sg = item.get_sub_group(); + + // offset along m dimension + int m_offset = item.get_global_id(0); + int c_offset_base = item.get_global_id(1) * vec_size; + + int thread_idx_y = item.get_local_id(0); + // int thread_idx_x = item.get_local_id(1); + int group_idx_y = item.get_group(0); + // int group_idx_x = item.get_group(1); + + int address_base = m_offset * stride_ + c_offset_base; + int inner_loop_stride = global_range_y_; + int address_increment = inner_loop_stride * stride_; + + accscalar_t x_sum[vec_size] = {0.0f}; + accscalar_t x_sq_sum[vec_size] = {0.0f}; + // thread reduction + for (int i = 0; i < loop_count_; i++) { + vec_t x_math_vec = *(reinterpret_cast(input_ptr_ + address_base)); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + auto c_offset = c_offset_base + j; + + if (c_offset < stride_ && m_offset < reduction_size_) { + // scalar_t arr = input_ptr_[address_base + j]; + auto x_math = x_math_vec[j]; + x_sum[j] += x_math; + x_sq_sum[j] += x_math * x_math; + } + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + +#pragma unroll + for (int j = 0; j < vec_size; j++) { + vec_y value; + value[0] = x_sum[j]; + value[1] = x_sq_sum[j]; + + value = group_y_reduce( + item, shared_, value, [](accscalar_t a, accscalar_t b) { + return a + b; + }); + + x_sum[j] = value[0]; + x_sq_sum[j] = value[1]; + + item.barrier(sycl_local_fence); + } + +#pragma unroll + for (int j = 0; j < vec_size; j++) { + auto c_offset = c_offset_base + j; + // global_reduciton + if (thread_idx_y == 0 && c_offset < stride_) { + if constexpr (two_pass_reduce) { + // write to temp[c][group_idx_y] + // int offset = c_offset * group_num_y_ + group_idx_y; + temp_sum_ptr_[c_offset * group_num_y_ + group_idx_y] = x_sum[j]; + temp_sum_sq_ptr_[c_offset * group_num_y_ + group_idx_y] = x_sq_sum[j]; + } else { + out_mean_ptr_[c_offset] = x_sum[j]; + out_invstd_ptr_[c_offset] = x_sq_sum[j]; + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(sycl::range<1>{(size_t)wg_size_}, cgh); + } + + BatchNormReduceSumChannelsLastKernelFunctor( + const int reduction_size, + const int stride, + int global_range_y, + int local_range_y, + int group_num_x, + int group_num_y, + accscalar_t* temp_sum_ptr, + accscalar_t* temp_sum_sq_ptr, + int wg_size, + scalar_t* input_ptr, + accscalar_t* out_mean_ptr, + accscalar_t* out_invstd_ptr, + int loop_count) + : reduction_size_(reduction_size), + stride_(stride), + global_range_y_(global_range_y), + local_range_y_(local_range_y), + group_num_x_(group_num_x), + group_num_y_(group_num_y), + temp_sum_ptr_(temp_sum_ptr), + temp_sum_sq_ptr_(temp_sum_sq_ptr), + wg_size_(wg_size), + input_ptr_(input_ptr), + out_mean_ptr_(out_mean_ptr), + out_invstd_ptr_(out_invstd_ptr), + loop_count_(loop_count) {} + + private: + const int reduction_size_; + const int stride_; + int global_range_y_; + int local_range_y_; + int group_num_x_; + int group_num_y_; + accscalar_t* temp_sum_ptr_; + accscalar_t* temp_sum_sq_ptr_; + int wg_size_; + scalar_t* input_ptr_; + accscalar_t* out_mean_ptr_; + accscalar_t* out_invstd_ptr_; + int loop_count_; + sycl_local_acc_t shared_; +}; + +template +struct BatchNormReduceSumChannelsLastTwoPassKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + auto local_id = item.get_local_linear_id(); + // auto global_id = item.get_global_linear_id(); + auto c_offset = item.get_group_linear_id(); + + accscalar_t temp_sum_val = 0.0f; + accscalar_t temp_sum_sq_val = 0.0f; + for (int i = local_id; i < group_num_y_; i += wg_size_) { + int offset = c_offset * group_num_y_ + i; + temp_sum_val += temp_sum_ptr_[offset]; + temp_sum_sq_val += temp_sum_sq_ptr_[offset]; + } + auto total_sum = sycl::reduce_over_group( + item.get_group(), temp_sum_val, sycl::plus()); + auto total_sum_sq = sycl::reduce_over_group( + item.get_group(), temp_sum_sq_val, sycl::plus()); + if (local_id == 0) { + out_mean_ptr_[c_offset] = total_sum; + out_invstd_ptr_[c_offset] = total_sum_sq; + } + } + BatchNormReduceSumChannelsLastTwoPassKernelFunctor( + int group_num_y, + accscalar_t* temp_sum_ptr, + accscalar_t* temp_sum_sq_ptr, + int wg_size, + accscalar_t* out_mean_ptr, + accscalar_t* out_invstd_ptr) + : group_num_y_(group_num_y), + temp_sum_ptr_(temp_sum_ptr), + temp_sum_sq_ptr_(temp_sum_sq_ptr), + wg_size_(wg_size), + out_mean_ptr_(out_mean_ptr), + out_invstd_ptr_(out_invstd_ptr) {} + + private: + int group_num_y_; + accscalar_t* temp_sum_ptr_; + accscalar_t* temp_sum_sq_ptr_; + int wg_size_; + accscalar_t* out_mean_ptr_; + accscalar_t* out_invstd_ptr_; +}; + +template +inline void welford_merge_group_vertical( + item_t item, + C& count, + T& mean, + T& m2n, + CACC& shmem_count, + TACC& shmem_mean, + TACC& shmem_m2n) { + // write to shared memory + auto address_base = item.get_local_linear_id(); + +#pragma unroll + for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { + if (item.get_local_id(0) < offset * 2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + item.barrier(sycl_local_fence); + if (item.get_local_id(0) < offset && + item.get_local_id(0) + offset < item.get_local_range(0)) { + auto address = address_base + offset * item.get_local_range(1); + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; + + welford_merge_element(count, mean, m2n, count_new, mean_new, m2n_new); + } + } +} + +template < + typename VarTransform, + typename scalar_t, + typename accscalar_t, + int PARALLEL_LOADS> +struct BatchNormCollectStatisticsChannelsLastKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + accscalar_t x_mean[PARALLEL_LOADS]; + accscalar_t m_2_n[PARALLEL_LOADS]; + int count[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + x_mean[i] = accscalar_t(0); + m_2_n[i] = accscalar_t(0); + count[i] = accscalar_t(0); + } + + // loop along m dimension + int inner_loop_stride = item.get_local_range(0) * item.get_group_range(0); + + // offset along m dimension + int m_offset = item.get_global_id(0); + int c_offset = item.get_global_id(1); + + int loop_count = + 1 + (reduction_size_ - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride_ + c_offset; + int address_increment = inner_loop_stride * stride_; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_math[PARALLEL_LOADS]; + accscalar_t x_count_inv[PARALLEL_LOADS]; + accscalar_t is_valid[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride_ && m_offset < reduction_size_) { + x_math[j] = input_[address_base]; + count[j]++; + x_count_inv[j] = accscalar_t(1) / count[j]; + is_valid[j] = accscalar_t(1); + } else { + x_math[j] = accscalar_t(0); + x_count_inv[j] = accscalar_t(0); + is_valid[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate mean/m2n with welford +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + accscalar_t delta0 = x_math[j] - x_mean[j]; + x_mean[j] += delta0 * x_count_inv[j]; + accscalar_t delta1 = x_math[j] - x_mean[j]; + m_2_n[j] += delta0 * delta1 * is_valid[j]; + } + } + + // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + welford_merge_element( + count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); + } + + // release x_mean / m_2_n + auto mean_th = x_mean[0]; + auto m2_th = m_2_n[0]; + auto count_th = count[0]; + + welford_merge_group_vertical( + item, count_th, mean_th, m2_th, shmem_count_, shmem_mean_, shmem_m2n_); + + if (item.get_group_range(0) > 1) { + volatile accscalar_t* staging_mean = staging_data_; + volatile accscalar_t* staging_m2n = + &staging_data_[stride_ * item.get_group_range(0)]; + volatile int* staging_count = reinterpret_cast( + &staging_m2n[stride_ * item.get_group_range(0)]); + + address_base = c_offset + item.get_group(0) * stride_; + // write data to staging_data; + if (item.get_local_id(0) == 0 && c_offset < stride_) { + staging_mean[address_base] = mean_th; + staging_m2n[address_base] = m2_th; + staging_count[address_base] = count_th; + } + + item.barrier(sycl_local_fence); + + // mark group done + if (item.get_local_linear_id() == 0) { + sycl_atomic_ref_rlx_dev_global_t count( + semaphores_[item.get_group(1)]); + int old = count.fetch_add( + 1, sycl_mem_odr_acq_rel + /* , default memory scope is device */); + is_last_group_done_[0] = (old == (item.get_group_range(0) - 1)); + } + + item.barrier(sycl_local_fence); + + // check that all data is now available in global memory + if (is_last_group_done_[0]) { + count_th = 0; + mean_th = accscalar_t(0.0); + m2_th = accscalar_t(0.0); + + for (int y = item.get_local_id(0); y < item.get_group_range(0); + y += item.get_local_range(0)) { + address_base = c_offset + y * stride_; + int count_new = c_offset < stride_ ? staging_count[address_base] : 0; + accscalar_t mean_new = c_offset < stride_ ? staging_mean[address_base] + : accscalar_t(0.0); + accscalar_t m2n_new = + c_offset < stride_ ? staging_m2n[address_base] : accscalar_t(0.0); + + welford_merge_element( + count_th, mean_th, m2_th, count_new, mean_new, m2n_new); + } + + welford_merge_group_vertical( + item, + count_th, + mean_th, + m2_th, + shmem_count_, + shmem_mean_, + shmem_m2n_); + if (item.get_local_id(0) == 0 && c_offset < stride_) { + out_mean_[c_offset] = static_cast(mean_th); + out_invstd_[c_offset] = VarTransform{}(m2_th / count_th, epsilon_); + } + } + } else { + if (item.get_group(0) == 0 && item.get_local_id(0) == 0 && + c_offset < stride_) { + out_mean_[c_offset] = static_cast(mean_th); + out_invstd_[c_offset] = VarTransform{}(m2_th / count_th, epsilon_); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + size_t max_wg_sz = syclMaxWorkGroupSize(); + shmem_mean_ = sycl_local_acc_t(sycl::range<1>{max_wg_sz}, cgh); + shmem_m2n_ = sycl_local_acc_t(sycl::range<1>{max_wg_sz}, cgh); + shmem_count_ = sycl_local_acc_t(sycl::range<1>{max_wg_sz}, cgh); + is_last_group_done_ = sycl_local_acc_t(sycl::range<1>{1}, cgh); + } + + BatchNormCollectStatisticsChannelsLastKernelFunctor( + const scalar_t* __restrict__ input, + accscalar_t* __restrict__ out_mean, + accscalar_t* __restrict__ out_invstd, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride, + accscalar_t epsilon) + : input_(input), + out_mean_(out_mean), + out_invstd_(out_invstd), + staging_data_(staging_data), + semaphores_(semaphores), + reduction_size_(reduction_size), + stride_(stride), + epsilon_(epsilon) {} + + private: + const scalar_t* __restrict__ input_; + accscalar_t* __restrict__ out_mean_; + accscalar_t* __restrict__ out_invstd_; + volatile accscalar_t* staging_data_; + int* semaphores_; + const int reduction_size_; + const int stride_; + accscalar_t epsilon_; + sycl_local_acc_t shmem_mean_; + sycl_local_acc_t shmem_m2n_; + sycl_local_acc_t shmem_count_; + sycl_local_acc_t is_last_group_done_; +}; + +template +void batch_norm_stats_channels_last_template( + Tensor& out_mean, + Tensor& out_invstd, + const Tensor& input, + double epsilon) { + using accscalar_t = acc_type; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::native::resize_output(out_mean, {stride}); + at::native::resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT( + out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT( + out_mean.dim() == 1 && out_mean.is_contiguous() && out_mean.sizes()[0]); + + auto config = get_adaptive_launch_config( + reduction_size, stride, true, ELEMENTS_PER_WORK_ITEM); + auto global_range = std::get<0>(config); + auto local_range = std::get<1>(config); + + at::Tensor staging_data; + at::Tensor semaphores; + auto wg_size_y = local_range[0]; + auto wg_size_x = local_range[1]; + auto nwg_y = global_range[0] / wg_size_y; + auto nwg_x = global_range[1] / wg_size_x; + if (nwg_y > 1) { + staging_data = at::empty({(long)(4 * stride * nwg_y)}, out_mean.options()); + semaphores = at::zeros({(long)nwg_x}, input.options().dtype(at::kInt)); + } + accscalar_t* staging_data_ptr = + nwg_y > 1 ? staging_data.mutable_data_ptr() : nullptr; + int* semaphores_ptr = + nwg_y > 1 ? semaphores.mutable_data_ptr() : nullptr; + + auto caller = BatchNormCollectStatisticsChannelsLastKernelFunctor< + VarTransform, + scalar_t, + accscalar_t, + ELEMENTS_PER_ITER>( + input.const_data_ptr(), + out_mean.mutable_data_ptr(), + out_invstd.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride, + epsilon); + sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), caller); +} + +std::tuple batch_norm_stats_kernel( + const Tensor& self, + double epsilon) { + auto options = + self.options().dtype(at::toAccumulateType(self.scalar_type(), true)); + auto n_channels = self.size(1); + auto save_mean = at::empty({n_channels}, options); + auto save_invstd = at::empty({n_channels}, options); + + bool use_channels_last_kernel = batch_norm_use_channels_last_kernels(self); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "batch_norm_stats_xpu", + [&] { + if (canUse32BitIndexMath(self)) { + if (use_channels_last_kernel) { + batch_norm_stats_channels_last_template( + save_mean, save_invstd, self, epsilon); + } else { + batch_norm_stats_template( + save_mean, save_invstd, self, epsilon); + } + } else { + batch_norm_stats_template( + save_mean, save_invstd, self, epsilon); + } + }); + return std::tuple(save_mean, save_invstd); +} + +// ========================== batch_norm_elemt ========================== + +template < + typename input_scalar_t, + typename stat_scalar_t, + typename stat_accscalar_t, + bool train, + typename index_t> +struct BatchNormTransformInputKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + index_t plane = item.get_group(1); + + if (plane >= input_.size(1)) { + return; + } + + stat_accscalar_t gamma = weight_.size(0) > 0 + ? static_cast(weight_[plane]) + : static_cast(1); + stat_accscalar_t beta = bias_.size(0) > 0 + ? static_cast(bias_[plane]) + : static_cast(0); + stat_accscalar_t mean = static_cast(mean_[plane]); + stat_accscalar_t invstd; + if constexpr (train) { + invstd = var_or_invstd_[plane]; + } else { + invstd = + static_cast(1) / + device_sqrt( + static_cast(var_or_invstd_[plane]) + epsilon_); + } + + index_t bs = input_.size(0); + index_t fs = input_.size(2); + + index_t bstep = item.get_local_range(0) * item.get_group_range(0); + for (index_t batch = item.get_global_id(0); batch < bs; batch += bstep) { + auto o = output_[batch][plane]; + auto i = input_[batch][plane]; + for (index_t feature = item.get_local_id(1); feature < fs; + feature += item.get_local_range(1)) { + o[feature] = static_cast( + gamma * (i[feature] - mean) * invstd + beta); + } + } + } + + BatchNormTransformInputKernelFunctor( + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + RestrictPtrTraits, + index_t> input, + GenericPackedTensorAccessor + output, + const GenericPackedTensorAccessor< + typename std::conditional:: + type, + 1, + RestrictPtrTraits, + index_t> mean, + const GenericPackedTensorAccessor< + typename std::conditional:: + type, + 1, + RestrictPtrTraits, + index_t> var_or_invstd, + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t> weight, + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t> bias, + stat_accscalar_t epsilon) + : input_(input), + output_(output), + mean_(mean), + var_or_invstd_(var_or_invstd), + weight_(weight), + bias_(bias), + epsilon_(epsilon) {} + + private: + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + RestrictPtrTraits, + index_t> + input_; + GenericPackedTensorAccessor + output_; + const GenericPackedTensorAccessor< + typename std::conditional::type, + 1, + RestrictPtrTraits, + index_t> + mean_; + const GenericPackedTensorAccessor< + typename std::conditional::type, + 1, + RestrictPtrTraits, + index_t> + var_or_invstd_; + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t> + weight_; + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t> + bias_; + stat_accscalar_t epsilon_; +}; + +template +void batch_norm_elemt_template( + const Tensor& output_, + const Tensor& input_, + const Tensor& weight_, + const Tensor& bias_, + const Tensor& mean_, + const Tensor& invstd_) { + using stat_accscalar_t = acc_type; + auto input_reshaped = input_.reshape( + {input_.size(0), + input_.size(1), + -1}); // internally we merge the feature dimensions + auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1}); + + auto input = + get_packed_accessor( + input_reshaped, "input"); + auto output = + get_packed_accessor( + output_reshaped, "output"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t>(weight_, "weight"); + auto bias = packed_accessor_or_dummy< + const stat_scalar_t, + 1, + RestrictPtrTraits, + index_t>(bias_, "bias"); + auto mean = + packed_accessor_or_dummy( + mean_, "mean"); + auto invstd = + packed_accessor_or_dummy( + invstd_, "invstd"); + auto& queue = getCurrentSYCLQueue(); + + // NOTE: We use transform_input_kernel in training mode, which ignores + // epsilon + const double dummy_epsilon = 1e-5; + + int tf = std::max( + get_num_threads(input.size(2) / 4), + std::min(get_num_threads(input.size(2)), 64)); + int tb = std::max(64 / tf, 1); + sycl::range<2> local_range(tb, tf); + int nwg_x = input.size(1); + int nwg_y = std::max( + 1, + std::min( + (256 * 1024) / input.size(1), (input.size(0) + tb - 1) / tb)); + nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); + sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); + + auto caller = BatchNormTransformInputKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + + sycl_kernel_submit(global_range, local_range, queue, caller); +} + +template < + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + int PARALLEL_LOADS> +struct BatchNormTransformInputChannelsLastKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = item.get_local_range(0) * item.get_group_range(0); + + // offset along m dimension + int m_offset = item.get_global_id(0); + int c_offset = item.get_global_id(1); + + if (c_offset >= stride_ || m_offset >= reduction_size_) { + return; + } + + auto m_c = mean_[c_offset]; + auto inv_std_c = static_cast(inv_std_[c_offset]); + auto w_c = weight_ == nullptr ? accscalar_t(1.0) + : static_cast(weight_[c_offset]); + auto s_c = shift_ == nullptr ? accscalar_t(0.0) + : static_cast(shift_[c_offset]); + + int loop_count = + 1 + (reduction_size_ - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride_ + c_offset; + int address_increment = inner_loop_stride * stride_; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride_ && m_offset < reduction_size_) { + auto tmp = w_c * + (static_cast(input_[address_base]) - m_c) * + inv_std_c + + s_c; + if (z_ != nullptr) { + tmp += z_[address_base]; + } + out_[address_base] = + (fuse_relu_ && tmp <= accscalar_t(0.0) + ? scalar_t(0.0) + : static_cast(tmp)); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } + } + + BatchNormTransformInputChannelsLastKernelFunctor( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ z, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const layerscalar_t* __restrict__ shift, + scalar_t* __restrict__ out, + const int reduction_size, + const int stride, + const bool fuse_relu) + : input_(input), + z_(z), + mean_(mean), + inv_std_(inv_std), + weight_(weight), + shift_(shift), + out_(out), + reduction_size_(reduction_size), + stride_(stride), + fuse_relu_(fuse_relu) {} + + private: + const scalar_t* __restrict__ input_; + const scalar_t* __restrict__ z_; + const accscalar_t* __restrict__ mean_; + const accscalar_t* __restrict__ inv_std_; + const layerscalar_t* __restrict__ weight_; + const layerscalar_t* __restrict__ shift_; + scalar_t* __restrict__ out_; + const int reduction_size_; + const int stride_; + const bool fuse_relu_; +}; + +void batch_norm_elemt_channels_last_template( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& shift, // bias of BN + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::optional& z = c10::nullopt, // bias after BN + const bool fuse_relu = false) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + auto config = get_adaptive_launch_config( + reduction_size, stride, false, ELEMENTS_PER_WORK_ITEM); + auto global_range = std::get<0>(config); + auto local_range = std::get<1>(config); + auto& queue = getCurrentSYCLQueue(); + const auto second_dtype = weight.defined() + ? weight.scalar_type() + : (shift.defined() ? shift.scalar_type() : input.scalar_type()); + + if (input.scalar_type() != second_dtype) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward_xpu", [&] { + using accscalar_t = at::acc_type; + auto caller = BatchNormTransformInputChannelsLastKernelFunctor< + scalar_t, + accscalar_t, + accscalar_t, + ELEMENTS_PER_ITER>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } else { + if (weight.defined()) { + TORCH_CHECK( + input.scalar_type() == weight.scalar_type(), + "batchnorm_forward: input.scalar_type() ", + input.scalar_type(), + " is not supported with weight.scalar_type() ", + weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "batchnorm_forward_xpu", [&] { + using accscalar_t = at::acc_type; + auto caller = BatchNormTransformInputChannelsLastKernelFunctor< + scalar_t, + accscalar_t, + scalar_t, + ELEMENTS_PER_ITER>( + input.const_data_ptr(), + z.has_value() ? z.value().const_data_ptr() : nullptr, + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + shift.defined() ? shift.const_data_ptr() : nullptr, + output.mutable_data_ptr(), + reduction_size, + stride, + fuse_relu); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } +} + +template +struct BatchNormElementwiseLoopsFunctor { + scalar_t operator()( + scalar_t input, + acc_t weight, + acc_t bias, + acc_t mean, + acc_t invstd) const { + volatile acc_t res = ((acc_t)input - mean) * weight * invstd + bias; + return res; + } +}; + +void batch_norm_elemt_kernel( + Tensor& out, + const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const Tensor& mean_, + const Tensor& invstd_) { + switch (batch_norm_choose_impl(self)) { + case Impl::Contiguous: { + c10::MaybeOwned weight = + at::borrow_from_optional_tensor(weight_opt); + c10::MaybeOwned bias = at::borrow_from_optional_tensor(bias_opt); + at::native::resize_output(out, self.sizes()); + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + self.scalar_type(), + "batch_norm_elementwise_xpu", + [&] { + using accscalar_t = acc_type; + const bool mixed_type = is_mixed_type(self, *weight, *bias); + if (mixed_type) { + batch_norm_elemt_template( + out, self, *weight, *bias, mean_, invstd_); + } else { + batch_norm_elemt_template( + out, self, *weight, *bias, mean_, invstd_); + } + }); + return; + } + case Impl::ChannelsLast: { + auto weight = at::borrow_from_optional_tensor(weight_opt); + auto bias = at::borrow_from_optional_tensor(bias_opt); + + if (resize_output_check(out, self.sizes())) { + resize_impl_xpu_( + out.unsafeGetTensorImpl(), self.sizes(), self.strides()); + } + if ((out.strides() == self.strides()) && + (!weight->defined() || weight->is_contiguous()) && + (!bias->defined() || bias->is_contiguous()) && + (!mean_.defined() || mean_.is_contiguous()) && + (!invstd_.defined() || invstd_.is_contiguous())) { + batch_norm_elemt_channels_last_template( + out, self, *weight, *bias, mean_, invstd_); + return; + } + [[fallthrough]]; + } + case Impl::General: { + const int64_t ndim = self.dim(); + DimVector sizes(ndim, 1), strides(ndim, 0); + // Helper to convert 1d tensors to an nd tensor that broadcasts with + // input All elements go into the channel dimension + auto as_nd = [&](const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1); + sizes[1] = t.sizes()[0]; + strides[1] = t.strides()[0]; + return t.as_strided(sizes, strides); + }; + + auto weight = weight_opt.has_value() && weight_opt->defined() + ? as_nd(*weight_opt) + : at::scalar_tensor(1, mean_.options()); + auto bias = bias_opt.has_value() && bias_opt->defined() + ? as_nd(*bias_opt) + : at::scalar_tensor(0, mean_.options()); + auto mean = as_nd(mean_); + auto invstd = as_nd(invstd_); + + auto iter = TensorIteratorConfig() + .add_output(out) + .add_input(self) + .add_input(weight) + .add_input(bias) + .add_input(mean) + .add_input(invstd) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + self.scalar_type(), + "batch_norm_elementwise_xpu", + [&] { + using acc_t = acc_type; + auto f = BatchNormElementwiseLoopsFunctor(); + gpu_kernel(iter, f); + }); + return; + } + } +} + +// ====================== batch_norm_backward_reduce ====================== + +template < + int SIMD, + typename input_scalar_t, + typename stat_scalar_t, + typename stat_accscalar_t, + typename index_t> +struct BatchNormBackwardReduceKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<2> item) const { + index_t plane = item.get_group(1); + + stat_accscalar_t r_mean = mean_[plane]; + stat_accscalar_t factor = invstd_[plane]; + + GradOp< + input_scalar_t, + stat_accscalar_t, + GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t>> + g(r_mean, input_, grad_output_); + int num_sg = item.get_local_range(1) * item.get_local_range(0) / SIMD; + auto res = plane_reduce>( + item, g, grad_output_, plane, num_sg, local_sum_); + + if (item.get_local_id(1) == 0) { + if (grad_weight_.size(0) > 0) { + auto grad_weight = grad_weight_; + grad_weight[plane] = static_cast(res.v2 * factor); + } + if (grad_bias_.size(0) > 0) { + auto grad_bias = grad_bias_; + grad_bias[plane] = static_cast(res.v1); + } + if (sum_dy_.size(0) > 0) { + auto sum_dy = sum_dy_; + sum_dy[plane] = static_cast(res.v1); + } + if (sum_dy_xmu_.size(0) > 0) { + auto sum_dy_xmu = sum_dy_xmu_; + sum_dy_xmu[plane] = static_cast(res.v2); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + local_sum_ = sycl_local_acc_t>( + sycl::range<1>{(size_t)get_max_group_size(SIMD)}, cgh); + } + + BatchNormBackwardReduceKernelFunctor( + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> input, + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> grad_output, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> mean, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> invstd, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> sum_dy, + GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> sum_dy_xmu, + GenericPackedTensorAccessor + grad_weight, + GenericPackedTensorAccessor + grad_bias) + : input_(input), + grad_output_(grad_output), + mean_(mean), + invstd_(invstd), + sum_dy_(sum_dy), + sum_dy_xmu_(sum_dy_xmu), + grad_weight_(grad_weight), + grad_bias_(grad_bias) {} + + private: + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + input_; + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + grad_output_; + GenericPackedTensorAccessor + mean_; + GenericPackedTensorAccessor + invstd_; + GenericPackedTensorAccessor + sum_dy_; + GenericPackedTensorAccessor + sum_dy_xmu_; + GenericPackedTensorAccessor + grad_weight_; + GenericPackedTensorAccessor + grad_bias_; + sycl_local_acc_t> local_sum_; +}; + +// supports CF and CL +template +std::tuple batch_norm_backward_reduce_template( + const Tensor& grad_out_, + const Tensor& input_, + const Tensor& mean_, + const Tensor& invstd_, + const Tensor& weight_, + const bool input_g, + const bool weight_g, + const bool bias_g) { + using stat_accscalar_t = acc_type; + int64_t n_input = input_.size(1); + Tensor sum_dy_; + Tensor sum_dy_xmu_; + Tensor grad_weight_; + Tensor grad_bias_; + + auto input_reshaped = input_.reshape( + {input_.size(0), + input_.size(1), + -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (input_g) { + sum_dy_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + sum_dy_xmu_ = at::empty_like(mean_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (weight_g) { + grad_weight_ = at::empty({n_input}, weight_.options()); + } + if (bias_g) { + grad_bias_ = at::empty({n_input}, weight_.options()); + } + + auto input = + get_packed_accessor( + input_reshaped, "input"); + auto grad_output = + get_packed_accessor( + grad_output_reshaped, "grad_output"); + auto grad_weight = + packed_accessor_or_dummy( + grad_weight_, "grad_weight"); + auto grad_bias = + packed_accessor_or_dummy( + grad_bias_, "grad_bias"); + auto mean = + packed_accessor_or_dummy( + mean_, "mean"); + auto invstd = + packed_accessor_or_dummy( + invstd_, "invstd"); + auto sum_dy = + packed_accessor_or_dummy( + sum_dy_, "sum_dy"); + auto sum_dy_xmu = + packed_accessor_or_dummy( + sum_dy_xmu_, "sum_dy_xmu"); + + auto batch_size = input_reshaped.size(0); + auto feature_size = input_reshaped.size(2); + auto& queue = getCurrentSYCLQueue(); + int simd = get_prefer_simd( + input_reshaped.size(1), input_reshaped.size(0) * input_reshaped.size(1)); + int max_wg_size = get_max_group_size(simd); + int wg_size_y = std::min(last_pow2(batch_size), max_wg_size / simd); + int wg_size_x = std::min( + std::max(get_num_threads(feature_size, simd), simd), + max_wg_size / wg_size_y); + sycl::range<2> local_range(wg_size_y, wg_size_x); + sycl::range<2> global_range(1 * wg_size_y, n_input * wg_size_x); + + if (simd == SIMD32) { + auto caller = BatchNormBackwardReduceKernelFunctor< + SIMD32, + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + index_t>( + input, + grad_output, + mean, + invstd, + sum_dy, + sum_dy_xmu, + grad_weight, + grad_bias); + sycl_kernel_submit(global_range, local_range, queue, caller); + } else { + auto caller = BatchNormBackwardReduceKernelFunctor< + SIMD16, + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + index_t>( + input, + grad_output, + mean, + invstd, + sum_dy, + sum_dy_xmu, + grad_weight, + grad_bias); + sycl_kernel_submit(global_range, local_range, queue, caller); + } + return std::make_tuple(sum_dy_, sum_dy_xmu_, grad_weight_, grad_bias_); +} + +template +inline void merge_group_vertical_backward( + item_t item, + T& sum_dy, + T& sum_dy_xmu, + TACC shmem_sum_dy, + TACC shmem_sum_dy_xmu) { + // write to shared memory + auto address_base = item.get_local_linear_id(); + int local_id_y = item.get_local_id(0); + +#pragma unroll + for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { + if (local_id_y < offset * 2) { + shmem_sum_dy[address_base] = sum_dy; + shmem_sum_dy_xmu[address_base] = sum_dy_xmu; + } + item.barrier(sycl_local_fence); + if (local_id_y < offset && local_id_y + offset < item.get_local_range(0)) { + auto address = address_base + offset * item.get_local_range(1); + + sum_dy += shmem_sum_dy[address]; + sum_dy_xmu += shmem_sum_dy_xmu[address]; + } + } +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t> +struct BatchNormBackwardReduceChannelsLastKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + // hide latency with concurrency + accscalar_t sum_dy[PARALLEL_LOADS]; + accscalar_t sum_dy_xmu[PARALLEL_LOADS]; + +#pragma unroll + for (int i = 0; i < PARALLEL_LOADS; i++) { + sum_dy[i] = accscalar_t(0); + sum_dy_xmu[i] = accscalar_t(0); + } + // tensor dimension (m,c) + + // loop along m dimension + int inner_loop_stride = item.get_local_range(0) * item.get_group_range(0); + + // offset along m dimension + int m_offset = item.get_global_id(0); + int c_offset = item.get_global_id(1); + + if (c_offset >= stride_ || m_offset >= reduction_size_) { + return; + } + + int loop_count = + 1 + (reduction_size_ - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride_ + c_offset; + int address_increment = inner_loop_stride * stride_; + + auto r_mean = mean_[c_offset]; + auto factor = inv_std_[c_offset]; + + for (int i = 0; i < loop_count; i++) { + accscalar_t x_input[PARALLEL_LOADS]; + accscalar_t x_grad_output[PARALLEL_LOADS]; + + // load multiple data in +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride_ && m_offset < reduction_size_) { + x_input[j] = input_[address_base]; + x_grad_output[j] = grad_output_[address_base]; + } else { + x_input[j] = accscalar_t(0); + x_grad_output[j] = accscalar_t(0); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + + // calculate sum_dy / sum_dy_xmu +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + sum_dy[j] += x_grad_output[j]; + sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); + } + } + +#pragma unroll + for (int j = 1; j < PARALLEL_LOADS; j++) { + sum_dy[0] += sum_dy[j]; + sum_dy_xmu[0] += sum_dy_xmu[j]; + } + + // release array of registers + auto sum_dy_th = sum_dy[0]; + auto sum_dy_xmu_th = sum_dy_xmu[0]; + + merge_group_vertical_backward( + item, sum_dy_th, sum_dy_xmu_th, shmem_sum_dy_, shmem_sum_dy_xmu_); + + auto nwg_y = item.get_group_range(0); + int tid_y = item.get_local_id(0); + + if (nwg_y > 1) { + volatile accscalar_t* staging_sum_dy = staging_data_; + volatile accscalar_t* staging_sum_dy_xmu = + &staging_data_[stride_ * nwg_y]; + + address_base = c_offset + item.get_group(0) * stride_; + // write data to staging_data; + if (tid_y == 0 && c_offset < stride_) { + staging_sum_dy[address_base] = sum_dy_th; + staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; + } + + item.barrier(sycl_local_fence); + + // mark group done + if (item.get_local_linear_id() == 0) { + sycl_atomic_ref_rlx_dev_global_t count( + semaphores_[item.get_group(1)]); + int old = count.fetch_add( + 1, sycl_mem_odr_acq_rel + /* , default memory scope is device */); + is_last_group_done_[0] = (old == (nwg_y - 1)); + } + + item.barrier(sycl_local_fence); + + // check that all data is now available in global memory + if (is_last_group_done_[0]) { + sum_dy_th = accscalar_t(0.0); + sum_dy_xmu_th = accscalar_t(0.0); + + for (int y = tid_y; y < nwg_y; y += item.get_local_range(0)) { + address_base = c_offset + y * stride_; + sum_dy_th += + (c_offset < stride_ ? staging_sum_dy[address_base] + : accscalar_t(0.0)); + sum_dy_xmu_th += + (c_offset < stride_ ? staging_sum_dy_xmu[address_base] + : accscalar_t(0.0)); + } + + merge_group_vertical_backward( + item, sum_dy_th, sum_dy_xmu_th, shmem_sum_dy_, shmem_sum_dy_xmu_); + if (tid_y == 0 && c_offset < stride_) { + if (grad_bias_ != nullptr) { + grad_bias_[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight_ != nullptr) { + grad_weight_[c_offset] = + static_cast(sum_dy_xmu_th * factor); + } + sum_dy_o_[c_offset] = sum_dy_th; + sum_dy_xmu_o_[c_offset] = sum_dy_xmu_th; + } + } + } else { + if (item.get_group(0) == 0 && tid_y == 0 && c_offset < stride_) { + if (grad_bias_ != nullptr) { + grad_bias_[c_offset] = static_cast(sum_dy_th); + } + if (grad_weight_ != nullptr) { + grad_weight_[c_offset] = + static_cast(sum_dy_xmu_th * factor); + } + sum_dy_o_[c_offset] = sum_dy_th; + sum_dy_xmu_o_[c_offset] = sum_dy_xmu_th; + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shmem_sum_dy_ = sycl_local_acc_t( + sycl::range<1>{(size_t)get_max_group_size()}, cgh); + shmem_sum_dy_xmu_ = sycl_local_acc_t( + sycl::range<1>{(size_t)get_max_group_size()}, cgh); + is_last_group_done_ = sycl_local_acc_t(sycl::range<1>{1}, cgh); + } + + BatchNormBackwardReduceChannelsLastKernelFunctor( + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ grad_output, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + accscalar_t* __restrict__ sum_dy_o, + accscalar_t* __restrict__ sum_dy_xmu_o, + layerscalar_t* __restrict__ grad_weight, + layerscalar_t* __restrict__ grad_bias, + volatile accscalar_t* staging_data, + int* semaphores, + const int reduction_size, + const int stride) + : input_(input), + grad_output_(grad_output), + mean_(mean), + inv_std_(inv_std), + sum_dy_o_(sum_dy_o), + sum_dy_xmu_o_(sum_dy_xmu_o), + grad_weight_(grad_weight), + grad_bias_(grad_bias), + staging_data_(staging_data), + semaphores_(semaphores), + reduction_size_(reduction_size), + stride_(stride) {} + + private: + const scalar_t* __restrict__ input_; + const scalar_t* __restrict__ grad_output_; + const accscalar_t* __restrict__ mean_; + const accscalar_t* __restrict__ inv_std_; + accscalar_t* __restrict__ sum_dy_o_; + accscalar_t* __restrict__ sum_dy_xmu_o_; + layerscalar_t* __restrict__ grad_weight_; + layerscalar_t* __restrict__ grad_bias_; + volatile accscalar_t* staging_data_; + int* semaphores_; + const int reduction_size_; + const int stride_; + sycl_local_acc_t shmem_sum_dy_; + sycl_local_acc_t shmem_sum_dy_xmu_; + sycl_local_acc_t is_last_group_done_; +}; + +std::tuple +batch_norm_backward_reduce_channels_last_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const bool input_g, + const bool weight_g, + const bool bias_g) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::Tensor sumn_dy = at::zeros({stride}, mean.options()); + at::Tensor sum_dy_xmu = at::zeros({stride}, mean.options()); + + at::Tensor grad_weight; + at::Tensor grad_bias; + if (weight.defined()) { + grad_weight = at::zeros({stride}, weight.options()); + grad_bias = at::zeros({stride}, weight.options()); + } else { + // because I cannot return an uninitialized at::Tensor + grad_weight = at::empty({0}, mean.options()); + grad_bias = at::empty({0}, mean.options()); + } + + auto config = get_adaptive_launch_config( + reduction_size, stride, false, ELEMENTS_PER_WORK_ITEM); + auto global_range = std::get<0>(config); + auto local_range = std::get<1>(config); + auto wg_size_y = local_range[0]; + auto wg_size_x = local_range[1]; + auto nwg_y = global_range[0] / wg_size_y; + auto nwg_x = global_range[1] / wg_size_x; + + at::Tensor staging_data; + at::Tensor semaphores; + if (nwg_y > 1) { + staging_data = at::empty({(long)(2 * stride * nwg_y)}, mean.options()); + semaphores = at::zeros({(long)nwg_x}, input.options().dtype(at::kInt)); + } + auto& queue = getCurrentSYCLQueue(); + + if (weight.defined() && input.scalar_type() != weight.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batchnorm_backward_reduce_xpu", + [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = nwg_y > 1 + ? staging_data.mutable_data_ptr() + : nullptr; + int* semaphores_ptr = + nwg_y > 1 ? semaphores.mutable_data_ptr() : nullptr; + auto caller = BatchNormBackwardReduceChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + accscalar_t>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + grad_weight.mutable_data_ptr(), + grad_bias.mutable_data_ptr(), + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } else { + if (weight.defined()) { + TORCH_CHECK( + input.scalar_type() == weight.scalar_type(), + "batchnorm_backward_reduce: input.scalar_type() ", + input.scalar_type(), + " is not supported with weight.scalar_type() ", + weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batchnorm_backward_reduce_xpu", + [&] { + using accscalar_t = at::acc_type; + accscalar_t* staging_data_ptr = nwg_y > 1 + ? staging_data.mutable_data_ptr() + : nullptr; + int* semaphores_ptr = + nwg_y > 1 ? semaphores.mutable_data_ptr() : nullptr; + auto caller = BatchNormBackwardReduceChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + scalar_t>( + input.const_data_ptr(), + grad_output.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + sumn_dy.mutable_data_ptr(), + sum_dy_xmu.mutable_data_ptr(), + weight.defined() ? grad_weight.mutable_data_ptr() + : nullptr, + weight.defined() ? grad_bias.mutable_data_ptr() + : nullptr, + staging_data_ptr, + semaphores_ptr, + reduction_size, + stride); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } + + return std::make_tuple(sumn_dy, sum_dy_xmu, grad_weight, grad_bias); +} + +std::tuple batch_norm_backward_reduce_kernel( + const Tensor& grad_output, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const c10::optional& weight_opt, + bool input_g, + bool weight_g, + bool bias_g) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + if (canUse32BitIndexMath(grad_output) && + batch_norm_use_channels_last_kernels(grad_output) && + batch_norm_use_channels_last_kernels(input) && + (!weight.defined() || weight.is_contiguous()) && mean.is_contiguous() && + invstd.is_contiguous()) { + return batch_norm_backward_reduce_channels_last_template( + grad_output, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + return AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_output.scalar_type(), + "batch_norm_backward_reduce_xpu", + [&] { + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK( + mean_st == invstd_st, + "mean and invstd need to have the same data types"); + const bool mixed_type = is_mixed_type(input, weight); + using accscalar_t = acc_type; + + if (canUse32BitIndexMath(grad_output)) { + if (mixed_type) { + return batch_norm_backward_reduce_template< + scalar_t, + accscalar_t, + int32_t>( + grad_output, + input, + mean, + invstd, + weight, + input_g, + weight_g, + bias_g); + } else { + return batch_norm_backward_reduce_template< + scalar_t, + scalar_t, + int32_t>( + grad_output, + input, + mean, + invstd, + weight, + input_g, + weight_g, + bias_g); + } + } else { + if (mixed_type) { + return batch_norm_backward_reduce_template< + scalar_t, + accscalar_t, + int64_t>( + grad_output, + input, + mean, + invstd, + weight, + input_g, + weight_g, + bias_g); + } else { + return batch_norm_backward_reduce_template< + scalar_t, + scalar_t, + int64_t>( + grad_output, + input, + mean, + invstd, + weight, + input_g, + weight_g, + bias_g); + } + } + }); +} + +// ====================== batch_norm_backward_elemt ====================== + +template < + typename input_scalar_t, + typename stat_scalar_t, + typename stat_accscalar_t, + typename index_t, + bool USE_COUNTS = false> +struct BatchNormBackwardElemtKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + stat_accscalar_t norm_fct; + if constexpr (USE_COUNTS) { + int64_t total_numel = 0; + for (int i = 0; i < world_size_; i++) { + total_numel += numel_[i]; + } + norm_fct = static_cast(1) / + static_cast(total_numel); + } else { + norm_fct = norm_fct_; + } + + index_t plane = item.get_group(1); + + if (plane >= input_.size(1)) { + return; + } + + stat_accscalar_t m_c = mean_[plane]; + stat_accscalar_t m_dy_c = sum_dy_[plane] * norm_fct; + stat_accscalar_t factor_1_c = invstd_[plane]; + stat_accscalar_t factor_2_c = weight_.size(0) > 0 + ? static_cast(weight_[plane]) + : stat_accscalar_t(1); + factor_2_c *= factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_[plane] * norm_fct; + + index_t bs = input_.size(0); + index_t fs = input_.size(2); + + index_t bstep = item.get_local_range(0) * item.get_group_range(0); + for (index_t batch = item.get_global_id(0); batch < bs; batch += bstep) { + auto g_i = grad_input_[batch][plane]; + auto g_o = grad_output_[batch][plane]; + auto i = input_[batch][plane]; + for (index_t feature = item.get_local_id(1); feature < fs; + feature += item.get_local_range(1)) { + g_i[feature] = static_cast( + (g_o[feature] - m_dy_c - (i[feature] - m_c) * factor_1_c) * + factor_2_c); + } + } + } + BatchNormBackwardElemtKernelFunctor( + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> input, + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> grad_output, + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> mean, + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> invstd, + const GenericPackedTensorAccessor< + stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> weight, + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> sum_dy, + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> sum_dy_xmu, + GenericPackedTensorAccessor + grad_input, + const stat_accscalar_t norm_fct, + const int* __restrict__ numel = nullptr, + const int world_size = 0) + : input_(input), + grad_output_(grad_output), + mean_(mean), + invstd_(invstd), + weight_(weight), + sum_dy_(sum_dy), + sum_dy_xmu_(sum_dy_xmu), + grad_input_(grad_input), + norm_fct_(norm_fct), + numel_(numel), + world_size_(world_size) {} + + private: + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + input_; + const GenericPackedTensorAccessor< + input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + grad_output_; + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + mean_; + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + invstd_; + const GenericPackedTensorAccessor + weight_; + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + sum_dy_; + const GenericPackedTensorAccessor< + stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + sum_dy_xmu_; + GenericPackedTensorAccessor + grad_input_; + const stat_accscalar_t norm_fct_; + const int* __restrict__ numel_; + const int world_size_; +}; + +template +Tensor batch_norm_backward_elemt_template( + const Tensor& grad_out_, + const Tensor& input_, + const Tensor& mean_, + const Tensor& invstd_, + const Tensor& weight_, + const Tensor& sum_dy_, + const Tensor& sum_dy_xmu_) { + using stat_accscalar_t = acc_type; + int64_t n_input = input_.size(1); + auto input_reshaped = input_.reshape( + {input_.size(0), + input_.size(1), + -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = + at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = + get_packed_accessor( + input_reshaped, "input"); + auto grad_input = + get_packed_accessor( + grad_input_reshaped, "grad_input"); + auto grad_output = + get_packed_accessor( + grad_output_reshaped, "grad_output"); + auto mean = + packed_accessor_or_dummy( + mean_, "mean"); + auto invstd = + packed_accessor_or_dummy( + invstd_, "invstd"); + auto weight = + packed_accessor_or_dummy( + weight_, "weight"); + auto sum_dy = + packed_accessor_or_dummy( + sum_dy_, "sum_dy"); + auto sum_dy_xmu = + packed_accessor_or_dummy( + sum_dy_xmu_, "sum_dy_xmu"); + + auto& queue = getCurrentSYCLQueue(); + int tf = std::max( + get_num_threads(input.size(2) / 4), + std::min(get_num_threads(input.size(2)), 64)); + int tb = std::max(64 / tf, 1); + int nwg_x = input.size(1); + int nwg_y = std::max( + 1, + std::min( + (256 * 1024) / input.size(1), (input.size(0) + tb - 1) / tb)); + nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); + auto reduction_size = input_.numel() / n_input; + auto norm_fct = static_cast(1.0 / reduction_size); + + sycl::range<2> local_range(tb, tf); + sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); + + auto caller = BatchNormBackwardElemtKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + index_t>( + input, + grad_output, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + grad_input, + norm_fct); + sycl_kernel_submit(global_range, local_range, queue, caller); + + return grad_input_reshaped.view(input_.sizes()); +} + +template +Tensor batch_norm_backward_elemt_template( + const Tensor& grad_out_, + const Tensor& input_, + const Tensor& mean_, + const Tensor& invstd_, + const Tensor& weight_, + const Tensor& sum_dy_, + const Tensor& sum_dy_xmu_, + const Tensor& count) { + using stat_accscalar_t = at::acc_type; + auto input_reshaped = input_.reshape( + {input_.size(0), + input_.size(1), + -1}); // internally we merge the feature dimensions + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + auto grad_input_reshaped = + at::empty_like(input_reshaped, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + auto input = + get_packed_accessor( + input_reshaped, "input"); + auto grad_input = + get_packed_accessor( + grad_input_reshaped, "grad_input"); + auto grad_output = + get_packed_accessor( + grad_output_reshaped, "grad_output"); + auto mean = + packed_accessor_or_dummy( + mean_, "mean"); + auto invstd = + packed_accessor_or_dummy( + invstd_, "invstd"); + auto weight = + packed_accessor_or_dummy( + weight_, "weight"); + auto sum_dy = + packed_accessor_or_dummy( + sum_dy_, "sum_dy"); + auto sum_dy_xmu = + packed_accessor_or_dummy( + sum_dy_xmu_, "sum_dy_xmu"); + + auto& queue = getCurrentSYCLQueue(); + + int tf = std::max( + get_num_threads(input.size(2) / 4), + std::min(get_num_threads(input.size(2)), 64)); + int tb = std::max(64 / tf, 1); + int nwg_x = input.size(1); + int nwg_y = std::max( + 1, + std::min( + (256 * 1024) / input.size(1), (input.size(0) + tb - 1) / tb)); + nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); + + sycl::range<2> local_range(tb, tf); + sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); + + auto caller = BatchNormBackwardElemtKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + index_t, + true>( + input, + grad_output, + mean, + invstd, + weight, + sum_dy, + sum_dy_xmu, + grad_input, + 0, + count.const_data_ptr(), + count.numel()); + sycl_kernel_submit(global_range, local_range, queue, caller); + + return grad_input_reshaped.view(input_.sizes()); +} + +template < + int PARALLEL_LOADS, + typename scalar_t, + typename accscalar_t, + typename layerscalar_t, + bool USE_COUNTS = false> +struct BatchNormBackwardElemtChannelsLastKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + accscalar_t norm_fct; + if constexpr (USE_COUNTS) { + int64_t total_numel = 0; + for (int i = 0; i < world_size_; i++) { + total_numel += numel_[i]; + } + norm_fct = + static_cast(1) / static_cast(total_numel); + } else { + norm_fct = norm_fct_; + } + + // tensor dimension (m,c) + // loop along m dimension + int inner_loop_stride = item.get_local_range(0) * item.get_group_range(0); + + // offset along m dimension + int m_offset = item.get_global_id(0); + int c_offset = item.get_global_id(1); + + if (c_offset >= stride_ || m_offset >= reduction_size_) { + return; + } + + auto m_c = mean_[c_offset]; + auto m_dy_c = sum_dy_[c_offset] * norm_fct; + auto factor_1_c = inv_std_[c_offset]; + auto factor_2_c = + (weight_ == nullptr ? accscalar_t(1.0) + : static_cast(weight_[c_offset])) * + factor_1_c; + factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu_[c_offset] * norm_fct; + + int loop_count = + 1 + (reduction_size_ - 1) / (inner_loop_stride * PARALLEL_LOADS); + int address_base = m_offset * stride_ + c_offset; + int address_increment = inner_loop_stride * stride_; + + for (int i = 0; i < loop_count; i++) { +#pragma unroll + for (int j = 0; j < PARALLEL_LOADS; j++) { + if (c_offset < stride_ && m_offset < reduction_size_) { + grad_input_[address_base] = static_cast( + (static_cast(grad_output_[address_base]) - m_dy_c - + (static_cast(input_[address_base]) - m_c) * + factor_1_c) * + factor_2_c); + } + m_offset += inner_loop_stride; + address_base += address_increment; + } + } + } + + BatchNormBackwardElemtChannelsLastKernelFunctor( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ input, + const accscalar_t* __restrict__ mean, + const accscalar_t* __restrict__ inv_std, + const layerscalar_t* __restrict__ weight, + const accscalar_t* __restrict__ sum_dy, + const accscalar_t* __restrict__ sum_dy_xmu, + scalar_t* __restrict__ grad_input, + const accscalar_t norm_fct, + const int reduction_size, + const int stride, + const int* __restrict__ numel = nullptr, + const int64_t world_size = 0) + : grad_output_(grad_output), + input_(input), + mean_(mean), + inv_std_(inv_std), + weight_(weight), + sum_dy_(sum_dy), + sum_dy_xmu_(sum_dy_xmu), + grad_input_(grad_input), + norm_fct_(norm_fct), + reduction_size_(reduction_size), + stride_(stride), + numel_(numel), + world_size_(world_size) {} + + private: + const scalar_t* __restrict__ grad_output_; + const scalar_t* __restrict__ input_; + const accscalar_t* __restrict__ mean_; + const accscalar_t* __restrict__ inv_std_; + const layerscalar_t* __restrict__ weight_; + const accscalar_t* __restrict__ sum_dy_; + const accscalar_t* __restrict__ sum_dy_xmu_; + scalar_t* __restrict__ grad_input_; + const accscalar_t norm_fct_; + const int reduction_size_; + const int stride_; + const int* __restrict__ numel_; + const int64_t world_size_; +}; + +at::Tensor batch_norm_backward_elemt_channels_last_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + auto norm_fct = 1.0 / reduction_size; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + auto config = get_adaptive_launch_config( + reduction_size, stride, true, ELEMENTS_PER_WORK_ITEM); + auto global_range = std::get<0>(config); + auto local_range = std::get<1>(config); + auto& queue = getCurrentSYCLQueue(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batchnorm_backward_element_xpu", + [&] { + using accscalar_t = at::acc_type; + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + auto caller = BatchNormBackwardElemtChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + accscalar_t>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + sycl_kernel_submit(global_range, local_range, queue, caller); + } else { + auto caller = BatchNormBackwardElemtChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + scalar_t>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + static_cast(norm_fct), + reduction_size, + stride); + sycl_kernel_submit(global_range, local_range, queue, caller); + } + }); + + return grad_input; +} + +at::Tensor batch_norm_backward_elemt_channels_last_template( + const at::Tensor& grad_output, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& inv_std, + const at::Tensor& weight, + const at::Tensor& sum_dy, + const at::Tensor& sum_dy_xmu, + const at::Tensor& count) { + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + // Input is guarunteed to be channels-last compatible + at::Tensor grad_input = at::empty_like(input); + + auto config = get_adaptive_launch_config( + reduction_size, stride, false, ELEMENTS_PER_WORK_ITEM); + auto global_range = std::get<0>(config); + auto local_range = std::get<1>(config); + auto& queue = getCurrentSYCLQueue(); + + if (weight.defined() && weight.scalar_type() != input.scalar_type()) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batchnorm_backward_element_xpu", + [&] { + using accscalar_t = acc_type; + auto caller = BatchNormBackwardElemtChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + accscalar_t, + true>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.const_data_ptr(), + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + 0, + reduction_size, + stride, + count.const_data_ptr(), + count.numel()); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } else { + if (weight.defined()) { + TORCH_CHECK( + input.scalar_type() == weight.scalar_type(), + "batchnorm_backward_element: input.scalar_type() ", + input.scalar_type(), + " is not supported with weight.scalar_type() ", + weight.scalar_type()); + } + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "batchnorm_backward_element_xpu", + [&] { + using accscalar_t = acc_type; + auto caller = BatchNormBackwardElemtChannelsLastKernelFunctor< + ELEMENTS_PER_ITER, + scalar_t, + accscalar_t, + scalar_t, + true>( + grad_output.const_data_ptr(), + input.const_data_ptr(), + mean.const_data_ptr(), + inv_std.const_data_ptr(), + weight.defined() ? weight.const_data_ptr() : nullptr, + sum_dy.const_data_ptr(), + sum_dy_xmu.const_data_ptr(), + grad_input.mutable_data_ptr(), + 0, + reduction_size, + stride, + count.const_data_ptr(), + count.numel()); + sycl_kernel_submit(global_range, local_range, queue, caller); + }); + } + + return grad_input; +} + +Tensor batch_norm_backward_elemt_kernel( + const Tensor& self, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const c10::optional& weight_opt, + const Tensor& sum_dy, + const Tensor& sum_dy_xmu, + const Tensor& count) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + if (canUse32BitIndexMath(self) && + batch_norm_use_channels_last_kernels(self) && + batch_norm_use_channels_last_kernels(input)) { + return batch_norm_backward_elemt_channels_last_template( + self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + + return AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "batch_norm_backward_elemt_xpu", + [&] { + auto mean_st = mean.dtype(); + auto invstd_st = invstd.dtype(); + TORCH_CHECK( + mean_st == invstd_st, + "mean and invstd need to have the same data types"); + bool is_half_float = + std::is_same::value && mean_st == at::kFloat; + bool is_bfloat16_float = std::is_same::value && + mean_st == at::kFloat; + using accscalar_t = acc_type; + if (canUse32BitIndexMath(self)) { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_template< + scalar_t, + accscalar_t, + int32_t>( + self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } else { + return batch_norm_backward_elemt_template< + scalar_t, + scalar_t, + int32_t>( + self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + } else { + if (is_half_float || is_bfloat16_float) { + return batch_norm_backward_elemt_template< + scalar_t, + accscalar_t, + int64_t>( + self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } else { + return batch_norm_backward_elemt_template< + scalar_t, + scalar_t, + int64_t>( + self, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + } + }); +} + +// ====================== batch_norm_update_stats ====================== + +template +struct BatchNormUpdateStatsFunctor { + std::tuple operator()( + acc_t mean, + acc_t var, + scalar_t running_mean, + scalar_t running_var) const { + const auto unbiased_var = var * bessel_correction_factor; + return std::tuple{ + mean * momentum + (1 - momentum) * running_mean, + unbiased_var * momentum + (1 - momentum) * running_var, + }; + } + + BatchNormUpdateStatsFunctor( + const acc_t bessel_correction_factor, + const acc_t momentum) + : bessel_correction_factor(bessel_correction_factor), + momentum(momentum) {} + + private: + const acc_t bessel_correction_factor; + const acc_t momentum; +}; + +void batch_norm_update_stats( + const Tensor& save_mean, + const Tensor& save_var, + const Tensor& running_mean, + const Tensor& running_var, + double momentum_, + int64_t N) { + auto iter = TensorIteratorConfig() + .add_output(running_mean) + .add_output(running_var) + .add_input(save_mean) + .add_input(save_var) + .add_input(running_mean) + .add_input(running_var) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + running_mean.scalar_type(), + "batch_norm_update_stats_xpu", + [&] { + using acc_t = acc_type; + const auto bessel_correction_factor = static_cast( + static_cast(N) / static_cast(N - 1)); + const auto momentum = static_cast(momentum_); + BatchNormUpdateStatsFunctor f( + bessel_correction_factor, momentum); + gpu_kernel_multiple_outputs(iter, f); + }); +} + +void batch_norm_mean_var( + const Tensor& self, + Tensor& save_mean, + Tensor& save_var) { + // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored. + const double dummy_epsilon = 1e-5; + switch (batch_norm_choose_impl(self)) { + case Impl::Contiguous: { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_xpu", [&] { + batch_norm_stats_template( + save_mean, save_var, self, dummy_epsilon); + }); + return; + } + case Impl::ChannelsLast: { + if ((!save_mean.defined() || save_mean.is_contiguous()) && + (!save_var.defined() || save_var.is_contiguous())) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, self.scalar_type(), "batch_norm_stats_xpu", [&] { + batch_norm_stats_channels_last_template( + save_mean, save_var, self, dummy_epsilon); + }); + return; + } + [[fallthrough]]; + } + case Impl::General: { + const int64_t ndim = self.dim(); + DimVector reduce_dims(ndim - 1); + reduce_dims[0] = 0; + for (int64_t i = 2; i < ndim; ++i) { + reduce_dims[i - 1] = i; + } + + // For some reason this isn't an actual operator but it exists anyway... + var_mean_out( + save_var, + save_mean, + self, + /*dims=*/reduce_dims, + /*unbiased=*/false, + /*keepdim=*/false); + return; + } + } +} + +std::tuple batch_norm_update_stats_kernel( + const Tensor& self, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + double momentum) { + c10::MaybeOwned running_mean = + at::borrow_from_optional_tensor(running_mean_opt); + c10::MaybeOwned running_var = + at::borrow_from_optional_tensor(running_var_opt); + + const int64_t n_input = self.size(1); + TORCH_CHECK( + self.numel() != 0, + "input tensor must have at least one element, but got input_sizes = ", + self.sizes()); + + auto options = + self.options().dtype(at::toAccumulateType(self.scalar_type(), true)); + + auto save_mean = at::empty({n_input}, options); + auto save_var = at::empty({n_input}, options); + + batch_norm_mean_var(self, save_mean, save_var); + TORCH_CHECK(running_mean->defined() == running_var->defined()); + if (running_mean->defined()) { + const int64_t N = self.numel() / save_mean.numel(); + batch_norm_update_stats( + save_mean, save_var, *running_mean, *running_var, momentum, N); + } + return std::tuple(save_mean, save_var); +} + +// ====================== native_batch_norm ====================== + +template +struct BatchNormUpdateStatsAndInvertFunctor { + std::tuple operator()( + acc_t mean, + acc_t var, + scalar_t running_mean, + scalar_t running_var) const { + const acc_t unbiased_var = var * bessel_correction_factor_; + volatile acc_t a = mean * momentum_ + (1 - momentum_) * (acc_t)running_mean; + volatile acc_t b = + unbiased_var * momentum_ + (1 - momentum_) * (acc_t)running_var; + volatile acc_t c = c10::xpu::compat::rsqrt(var + eps_); + return std::tuple{a, b, c}; + } + + BatchNormUpdateStatsAndInvertFunctor( + const acc_t bessel_correction_factor, + const acc_t eps, + const acc_t momentum) + : bessel_correction_factor_(bessel_correction_factor), + eps_(eps), + momentum_(momentum) {} + + private: + const acc_t bessel_correction_factor_; + const acc_t eps_; + const acc_t momentum_; +}; + +void batch_norm_update_stats_and_invert( + const Tensor& save_mean, + const Tensor& save_var, + const Tensor& running_mean, + const Tensor& running_var, + double momentum_, + double epsilon, + int64_t N) { + auto iter = TensorIteratorConfig() + .add_output(running_mean) + .add_output(running_var) + .add_output(save_var) + .add_const_input(save_mean) + .add_input(save_var) + .add_input(running_mean) + .add_input(running_var) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + running_mean.scalar_type(), + "batch_norm_update_stats_and_invert_xpu", + [&] { + using acc_t = acc_type; + const auto bessel_correction_factor = static_cast( + static_cast(N) / static_cast(N - 1)); + const auto eps = static_cast(epsilon); + const auto momentum = static_cast(momentum_); + BatchNormUpdateStatsAndInvertFunctor f( + bessel_correction_factor, eps, momentum); + gpu_kernel_multiple_outputs(iter, f); + }); +} + +template +struct BatchNormCalcInvstdFunctor { + acc_t operator()(scalar_t var) const { + volatile acc_t v = var + eps_; + return c10::xpu::compat::rsqrt(v); + } + + BatchNormCalcInvstdFunctor(acc_t eps) : eps_(eps) {} + + private: + acc_t eps_; +}; + +void batch_norm_calc_invstd( + const Tensor& out_invstd, + const Tensor& running_var, + double epsilon) { + auto iter = TensorIteratorConfig() + .add_output(out_invstd) + .add_input(running_var) + .check_all_same_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + running_var.scalar_type(), + "batch_norm_invert_std_xpu", + [&] { + using acc_t = at::acc_type; + auto eps = static_cast(epsilon); + BatchNormCalcInvstdFunctor f(eps); + gpu_kernel(iter, f); + }); +} + +template +struct BatchNormElementwiseFunctor { + scalar_t operator()( + scalar_t input, + acc_t weight, + acc_t bias, + acc_t mean, + acc_t invstd) const { + return ((input - mean) * invstd) * weight + bias; + } +}; + +void batch_norm_elementwise( + const Tensor& out, + const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const Tensor& mean_, + const Tensor& invstd_) { + switch (batch_norm_choose_impl(self)) { + case Impl::Contiguous: { + c10::MaybeOwned weight = + at::borrow_from_optional_tensor(weight_opt); + c10::MaybeOwned bias = at::borrow_from_optional_tensor(bias_opt); + at::native::resize_output(out, self.sizes()); + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + self.scalar_type(), + "batch_norm_elementwise_xpu", + [&] { + using accscalar_t = at::acc_type; + const bool mixed_type = is_mixed_type(self, *weight, *bias); + if (mixed_type) { + batch_norm_elemt_template( + out, self, *weight, *bias, mean_, invstd_); + } else { + batch_norm_elemt_template( + out, self, *weight, *bias, mean_, invstd_); + } + }); + return; + } + case Impl::ChannelsLast: { + auto weight = at::borrow_from_optional_tensor(weight_opt); + auto bias = at::borrow_from_optional_tensor(bias_opt); + + if (resize_output_check(out, self.sizes())) { + resize_impl_xpu_( + out.unsafeGetTensorImpl(), self.sizes(), self.strides()); + } + if ((out.strides() == self.strides()) && + (!weight->defined() || weight->is_contiguous()) && + (!bias->defined() || bias->is_contiguous()) && + (!mean_.defined() || mean_.is_contiguous()) && + (!invstd_.defined() || invstd_.is_contiguous())) { + batch_norm_elemt_channels_last_template( + out, self, *weight, *bias, mean_, invstd_); + return; + } + [[fallthrough]]; + } + case Impl::General: { + const int64_t ndim = self.dim(); + DimVector sizes(ndim, 1), strides(ndim, 0); + // Helper to convert 1d tensors to an nd tensor that broadcasts with input + // All elements go into the channel dimension + auto as_nd = [&](const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1); + sizes[1] = t.sizes()[0]; + strides[1] = t.strides()[0]; + return t.as_strided(sizes, strides); + }; + + auto weight = weight_opt.has_value() && weight_opt->defined() + ? as_nd(*weight_opt) + : at::scalar_tensor(1, mean_.options()); + auto bias = bias_opt.has_value() && bias_opt->defined() + ? as_nd(*bias_opt) + : at::scalar_tensor(0, mean_.options()); + auto mean = as_nd(mean_); + auto invstd = as_nd(invstd_); + + auto iter = TensorIteratorConfig() + .add_output(out) + .add_input(self) + .add_input(weight) + .add_input(bias) + .add_input(mean) + .add_input(invstd) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, + kHalf, + self.scalar_type(), + "batch_norm_elementwise_xpu", + [&] { + using acc_t = at::acc_type; + BatchNormElementwiseFunctor f; + gpu_kernel(iter, f); + }); + return; + } + } +} + +std::tuple batch_norm_kernel( + const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + bool train, + double momentum, + double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_invstd) { + const bool has_running_mean = + (running_mean_opt.has_value() && running_mean_opt->defined()); + const bool has_running_var = + (running_var_opt.has_value() && running_var_opt->defined()); + TORCH_CHECK(has_running_mean == has_running_var); + + if (train) { + batch_norm_mean_var(self, save_mean, save_invstd); + if (has_running_mean) { + const int64_t N = self.numel() / save_mean.numel(); + batch_norm_update_stats_and_invert( + save_mean, + save_invstd, + *running_mean_opt, + *running_var_opt, + momentum, + epsilon, + N); + } else { + batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); + } + } else { + TORCH_CHECK(has_running_mean); + at::native::resize_output(save_mean, running_mean_opt->sizes()); + save_mean.copy_(*running_mean_opt, /*non_blocking=*/true); + batch_norm_calc_invstd(save_invstd, running_var_opt.value(), epsilon); + } + + batch_norm_elementwise( + output, self, weight_opt, bias_opt, save_mean, save_invstd); + return std::tuple(output, save_mean, save_invstd); +} + +// ====================== native_batch_norm_bw ====================== + +template < + int SIMD, + typename input_scalar_t, + typename stat_scalar_t, + typename stat_accscalar_t, + typename index_t> +struct BatchNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<2> item) const { + index_t plane = item.get_group(1); + index_t N = grad_output_.size(0) * grad_output_.size(2); + + stat_accscalar_t mean, invstd; + if (train_) { + mean = save_mean_[plane]; + invstd = save_invstd_[plane]; + } else { + mean = static_cast(running_mean_[plane]); + invstd = + static_cast(1) / + std::sqrt( + static_cast(running_var_[plane]) + epsilon_); + } + + stat_accscalar_t weight_val = weight_.size(0) > 0 + ? static_cast(weight_[plane]) + : stat_accscalar_t(1); + stat_accscalar_t norm = stat_accscalar_t(1) / N; + + // Compute two values across (batch, x/y/z) in one pass: + // 1. Sum(grad_output) + // 2. DotProduct(input - mean, grad_output) + GradOp< + input_scalar_t, + stat_accscalar_t, + GenericPackedTensorAccessor< + const input_scalar_t, + 3, + DefaultPtrTraits, + index_t>> + g(mean, input_, grad_output_); + int num_sg = item.get_local_range(1) * item.get_local_range(0) / SIMD; + auto res = plane_reduce>( + item, g, grad_output_, plane, num_sg, local_sum_); + + stat_accscalar_t grad_output_sum = res.v1; + stat_accscalar_t dot_p = res.v2; + + stat_accscalar_t grad_mean = grad_output_sum * norm; + stat_accscalar_t proj_scale = dot_p * norm * invstd * invstd; + stat_accscalar_t grad_scale = invstd * weight_val; + + auto grad_input = grad_input_; + if (grad_input_.data() != NULL) { + for (int batch = item.get_local_id(0); batch < grad_output_.size(0); + batch += item.get_local_range(0)) { + for (int x = item.get_local_id(1); x < grad_output_.size(2); + x += item.get_local_range(1)) { + input_scalar_t go = grad_output_[batch][plane][x]; + if (train_) { + stat_accscalar_t inp = input_[batch][plane][x]; + stat_accscalar_t proj = (inp - mean) * proj_scale; + grad_input[batch][plane][x] = static_cast( + (go - proj - grad_mean) * grad_scale); + } else { + grad_input[batch][plane][x] = + static_cast(go * grad_scale); + } + } + } + } + + if (grad_weight_.size(0) > 0) { + if (item.get_local_id(1) == 0) { + auto grad_weight = grad_weight_; + grad_weight[plane] = static_cast(dot_p * invstd); + } + } + + if (grad_bias_.size(0) > 0) { + if (item.get_local_id(1) == 0) { + auto grad_bias = grad_bias_; + grad_bias[plane] = static_cast(grad_output_sum); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + local_sum_ = sycl_local_acc_t>( + sycl::range<1>{(size_t)get_max_group_size(SIMD)}, cgh); + } + + BatchNormBackwardKernelFunctor( + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + DefaultPtrTraits, + index_t> input, + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + DefaultPtrTraits, + index_t> grad_output, + GenericPackedTensorAccessor + grad_input, + GenericPackedTensorAccessor + grad_weight, + GenericPackedTensorAccessor + grad_bias, + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> weight, + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> running_mean, + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> running_var, + const GenericPackedTensorAccessor< + const stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> save_mean, + const GenericPackedTensorAccessor< + const stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> save_invstd, + bool train, + stat_accscalar_t epsilon) + : input_(input), + grad_output_(grad_output), + grad_input_(grad_input), + grad_weight_(grad_weight), + grad_bias_(grad_bias), + weight_(weight), + running_mean_(running_mean), + running_var_(running_var), + save_mean_(save_mean), + save_invstd_(save_invstd), + train_(train), + epsilon_(epsilon) {} + + private: + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + input_; + const GenericPackedTensorAccessor< + const input_scalar_t, + 3, + DefaultPtrTraits, + index_t> + grad_output_; + GenericPackedTensorAccessor + grad_input_; + GenericPackedTensorAccessor + grad_weight_; + GenericPackedTensorAccessor + grad_bias_; + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> + weight_; + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> + running_mean_; + const GenericPackedTensorAccessor< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t> + running_var_; + const GenericPackedTensorAccessor< + const stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + save_mean_; + const GenericPackedTensorAccessor< + const stat_accscalar_t, + 1, + DefaultPtrTraits, + index_t> + save_invstd_; + bool train_; + stat_accscalar_t epsilon_; + sycl_local_acc_t> local_sum_; +}; + +template +std::tuple batch_norm_backward_template( + const Tensor& grad_out_, + const Tensor& input_, + const Tensor& weight_, + const Tensor& running_mean_, + const Tensor& running_var_, + const Tensor& save_mean_, + const Tensor& save_invstd_, + bool train, + double epsilon, + std::array grad_input_mask) { + using accscalar_t = acc_type; + Tensor grad_input_; + Tensor grad_input_reshaped; + Tensor grad_weight_; + Tensor grad_bias_; + auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); + auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); + + if (grad_input_mask[0]) { + grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); + } + if (grad_input_mask[1]) { + grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[2]) { + grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + auto input = + get_packed_accessor( + input_reshaped, "input"); + auto grad_output = + get_packed_accessor( + grad_output_reshaped, "grad_output"); + auto grad_input = + packed_accessor_or_dummy( + grad_input_reshaped, "grad_input"); + auto weight = packed_accessor_or_dummy< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t>(weight_, "weight"); + auto grad_weight = + packed_accessor_or_dummy( + grad_weight_, "grad_weight"); + auto grad_bias = + packed_accessor_or_dummy( + grad_bias_, "grad_bias"); + auto running_mean = packed_accessor_or_dummy< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t>(running_mean_, "running_mean"); + auto running_var = packed_accessor_or_dummy< + const stat_scalar_t, + 1, + DefaultPtrTraits, + index_t>(running_var_, "running_var"); + auto save_mean = + packed_accessor_or_dummy( + save_mean_, "save_mean"); + auto save_invstd = + packed_accessor_or_dummy( + save_invstd_, "save_invstd"); + + int simd = get_prefer_simd( + input_reshaped.size(1), input_reshaped.size(0) * input_reshaped.size(1)); + + auto& queue = getCurrentSYCLQueue(); + int max_group_size = get_max_group_size(simd); + int tf = get_num_threads(input.size(2), simd); + int wg_sz_y = std::max(1, max_group_size / tf); + sycl::range<2> local_range(wg_sz_y, tf); + sycl::range<2> global_range(1 * wg_sz_y, input.size(1) * tf); + + if (simd == SIMD32) { + auto caller = BatchNormBackwardKernelFunctor< + SIMD32, + input_scalar_t, + stat_scalar_t, + accscalar_t, + index_t>( + input, + grad_output, + grad_input, + grad_weight, + grad_bias, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + epsilon); + sycl_kernel_submit(global_range, local_range, queue, caller); + } else { + auto caller = BatchNormBackwardKernelFunctor< + SIMD16, + input_scalar_t, + stat_scalar_t, + accscalar_t, + index_t>( + input, + grad_output, + grad_input, + grad_weight, + grad_bias, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + epsilon); + sycl_kernel_submit(global_range, local_range, queue, caller); + } + return std::make_tuple(grad_input_, grad_weight_, grad_bias_); +} + +template +struct BatchNormElementwiseBackwardTrainFunctor { + scalar_t operator()( + scalar_t gO, + scalar_t input, + accscalar_t weight, + accscalar_t mean, + accscalar_t invstd, + accscalar_t xmu, + accscalar_t dy) const { + auto factor_1_c = invstd * invstd * xmu * norm_fct_; + auto factor_2_c = weight * invstd; + auto m_dy_c = dy * norm_fct_; + volatile accscalar_t res = + ((accscalar_t)gO - m_dy_c - ((accscalar_t)input - mean) * factor_1_c) * + factor_2_c; + return res; + } + + BatchNormElementwiseBackwardTrainFunctor(accscalar_t norm_fct) + : norm_fct_(norm_fct) {} + + private: + accscalar_t norm_fct_; +}; + +Tensor batch_norm_elementwise_backward_train( + const Tensor& grad_out, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const Tensor& weight, + const Tensor& sum_dy, + const Tensor& sum_dy_xmu) { + switch (batch_norm_choose_impl(input, grad_out)) { + case Impl::Contiguous: { + return AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batch_norm_backward_elemt_xpu", + [&] { + using accscalar_t = at::acc_type; + const bool mixed_type = is_mixed_type(input, weight); + if (mixed_type) { + return batch_norm_backward_elemt_template< + scalar_t, + accscalar_t, + int32_t>( + grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu); + } else { + return batch_norm_backward_elemt_template< + scalar_t, + scalar_t, + int32_t>( + grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu); + } + }); + } + case Impl::ChannelsLast: { + if ((!weight.defined() || weight.is_contiguous()) && + mean.is_contiguous() && invstd.is_contiguous()) { + return batch_norm_backward_elemt_channels_last_template( + grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu); + } + [[fallthrough]]; + } + case Impl::General: { + const auto ndim = input.dim(); + DimVector sizes(ndim, 1), strides(ndim, 0); + auto as_nd = [&](const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1); + sizes[1] = t.sizes()[0]; + strides[1] = t.strides()[0]; + return t.as_strided(sizes, strides); + }; + auto invstd_nd = as_nd(invstd); + auto mean_nd = as_nd(mean); + auto sum_dy_nd = as_nd(sum_dy); + auto sum_dy_xmu_nd = as_nd(sum_dy_xmu); + auto weight_nd = weight.defined() + ? as_nd(weight) + : at::scalar_tensor(1.0, input.options().dtype(mean.scalar_type())); + + Tensor grad_input = at::empty( + input.sizes(), + grad_out.options().memory_format(input.suggest_memory_format())); + auto iter = TensorIteratorConfig() + .add_output(grad_input) + .add_input(grad_out) + .add_input(input) + .add_input(weight_nd) + .add_input(mean_nd) + .add_input(invstd_nd) + .add_input(sum_dy_xmu_nd) + .add_input(sum_dy_nd) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_out.scalar_type(), + "batch_norm_eval_backward_xpu", + [&] { + using accscalar_t = at::acc_type; + auto norm_fct = + static_cast(1.0 / (input.numel() / input.size(1))); + BatchNormElementwiseBackwardTrainFunctor f( + norm_fct); + gpu_kernel(iter, f); + }); + return grad_input; + } + } + TORCH_INTERNAL_ASSERT(false); +} + +template +struct BatchNormElementwiseBackwardEvalWithWeightfunctor { + scalar_t operator()(scalar_t gO, accscalar_t invstd, accscalar_t weight) + const { + volatile accscalar_t res = (accscalar_t)gO * weight * invstd; + return res; + } +}; + +template +struct BatchNormElementwiseBackwardEvalfunctor { + scalar_t operator()(scalar_t gO, accscalar_t invstd) const { + volatile accscalar_t res = (accscalar_t)gO * invstd; + return res; + } +}; + +Tensor batch_norm_elementwise_backward_eval( + const Tensor& grad_out, + const Tensor& input, + const Tensor& invstd, + const Tensor& weight) { + const auto ndim = input.dim(); + DimVector shape(ndim, 1), strides(ndim, 0); + shape[1] = invstd.sizes()[0]; + strides[1] = invstd.strides()[0]; + auto invstd_nd = invstd.as_strided(shape, strides); + Tensor grad_input = at::empty(input.sizes(), grad_out.options()); + + if (weight.defined()) { + strides[1] = weight.strides()[0]; + auto weight_nd = weight.as_strided(shape, strides); + auto iter = TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(grad_out) + .add_const_input(invstd_nd) + .add_const_input(weight_nd) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_out.scalar_type(), + "batch_norm_eval_backward_xpu", + [&] { + using accscalar_t = at::acc_type; + BatchNormElementwiseBackwardEvalWithWeightfunctor< + scalar_t, + accscalar_t> + f; + gpu_kernel(iter, f); + }); + } else { + auto iter = TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(grad_out) + .add_const_input(invstd_nd) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_out.scalar_type(), + "batch_norm_eval_backward_xpu", + [&] { + using accscalar_t = at::acc_type; + BatchNormElementwiseBackwardEvalfunctor f; + gpu_kernel(iter, f); + }); + } + return grad_input; +} + +std::tuple batch_norm_backward_kernel( + const Tensor& grad_out, + const Tensor& input, + const c10::optional& weight_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + const c10::optional& save_mean_opt, + const c10::optional& save_invstd_opt, + bool train, + double epsilon, + std::array grad_input_mask) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned weight = at::borrow_from_optional_tensor(weight_opt); + c10::MaybeOwned save_mean = + at::borrow_from_optional_tensor(save_mean_opt); + c10::MaybeOwned save_invstd = + at::borrow_from_optional_tensor(save_invstd_opt); + c10::MaybeOwned running_mean = + at::borrow_from_optional_tensor(running_mean_opt); + c10::MaybeOwned running_var = + at::borrow_from_optional_tensor(running_var_opt); + + const bool needs_reduction = + train || grad_input_mask[1] || grad_input_mask[2]; + + // Fused reduction & elementwise kernel + if (needs_reduction && grad_input_mask[0] && + !batch_norm_use_channels_last_kernels(input) && + canUse32BitIndexMath(input) && canUse32BitIndexMath(grad_out)) { + return AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "batch_norm_backward_xpu", [&] { + using accscalar_t = at::acc_type; + const bool mixed_type = + is_mixed_type(input, *weight, *running_mean, *running_var); + if (mixed_type) { + return batch_norm_backward_template( + grad_out, + input, + *weight, + *running_mean, + *running_var, + *save_mean, + *save_invstd, + train, + epsilon, + grad_input_mask); + } else { + return batch_norm_backward_template( + grad_out, + input, + *weight, + *running_mean, + *running_var, + *save_mean, + *save_invstd, + train, + epsilon, + grad_input_mask); + } + }); + } + + const auto acc_type = at::toAccumulateType(input.scalar_type(), true); + Tensor mean; + TORCH_INTERNAL_ASSERT( + save_mean->defined(), "save_mean should always be defined\n"); + if (save_mean->numel() != 0) { + mean = *save_mean; + } else if (needs_reduction) { + TORCH_CHECK(!train && running_mean->defined()); + mean = (running_mean->scalar_type() == acc_type) + ? *running_mean + : running_mean->to(acc_type); + } + + Tensor invstd; + TORCH_INTERNAL_ASSERT( + save_invstd->defined(), "save_invstd should always be defined\n"); + if (save_invstd->numel() != 0) { + invstd = *save_invstd; + } else { + TORCH_CHECK(!train && running_var->defined()); + auto n_channels = input.sizes()[1]; + invstd = at::empty({n_channels}, input.options().dtype(acc_type)); + batch_norm_calc_invstd(invstd, *running_var, epsilon); + } + + Tensor sum_dy, sum_dy_xmu, grad_weight, grad_bias; + if (needs_reduction) { + std::tie(sum_dy, sum_dy_xmu, grad_weight, grad_bias) = + batch_norm_backward_reduce_kernel( + grad_out, + input, + mean, + invstd, + *weight, + grad_input_mask[0], + grad_input_mask[1], + grad_input_mask[2]); + } + + Tensor grad_input; + if (grad_input_mask[0]) { + if (train) { + // NOTE: sum_dy and sum_dy_xmy are defined, as train implies + // needs_reduction + grad_input = batch_norm_elementwise_backward_train( + grad_out, input, mean, invstd, *weight, sum_dy, sum_dy_xmu); + } else { + grad_input = batch_norm_elementwise_backward_eval( + grad_out, input, invstd, *weight); + } + } + + return std::make_tuple(grad_input, grad_weight, grad_bias); +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.h b/src/ATen/native/xpu/sycl/BatchNormKernels.h new file mode 100644 index 000000000..3bc559b38 --- /dev/null +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +std::tuple batch_norm_stats_kernel( + const Tensor& self, + double epsilon); + +void batch_norm_elemt_kernel( + Tensor& out, + const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const Tensor& mean_, + const Tensor& invstd_); + +std::tuple batch_norm_backward_reduce_kernel( + const Tensor& grad_output, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const c10::optional& weight_opt, + bool input_g, + bool weight_g, + bool bias_g); + +Tensor batch_norm_backward_elemt_kernel( + const Tensor& self, + const Tensor& input, + const Tensor& mean, + const Tensor& invstd, + const c10::optional& weight_opt, + const Tensor& sum_dy, + const Tensor& sum_dy_xmu, + const Tensor& count); + +std::tuple batch_norm_update_stats_kernel( + const Tensor& self, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + double momentum); + +std::tuple batch_norm_kernel( + const Tensor& self, + const c10::optional& weight_opt, + const c10::optional& bias_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + bool train, + double momentum, + double epsilon, + Tensor& output, + Tensor& save_mean, + Tensor& save_invstd); + +std::tuple batch_norm_backward_kernel( + const Tensor& grad_out, + const Tensor& input, + const c10::optional& weight_opt, + const c10::optional& running_mean_opt, + const c10::optional& running_var_opt, + const c10::optional& save_mean_opt, + const c10::optional& save_invstd_opt, + bool train, + double epsilon, + std::array grad_input_mask); + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/BinaryGeometricKernels.cpp b/src/ATen/native/xpu/sycl/BinaryGeometricKernels.cpp new file mode 100644 index 000000000..c93afe4bf --- /dev/null +++ b/src/ATen/native/xpu/sycl/BinaryGeometricKernels.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace xpu { + +template +struct HypotFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + return std::hypot(a, b); + } +}; + +void hypot_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.common_dtype(), + "hypot_xpu", + [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, HypotFunctor()); + }); +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/BinaryGeometricKernels.h b/src/ATen/native/xpu/sycl/BinaryGeometricKernels.h new file mode 100644 index 000000000..e37dd6dbf --- /dev/null +++ b/src/ATen/native/xpu/sycl/BinaryGeometricKernels.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void hypot_kernel(TensorIteratorBase& iter); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ComplexKernels.cpp b/src/ATen/native/xpu/sycl/ComplexKernels.cpp new file mode 100644 index 000000000..686f5e2d3 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ComplexKernels.cpp @@ -0,0 +1,23 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#include + +namespace at::native::xpu { + +template +struct ComplexFunctor { + c10::complex operator()(scalar_t a, scalar_t b) const { + return c10::complex(a, b); + } +}; + +void complex_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.input_dtype(0), "complex_xpu", [&]() { + ComplexFunctor f; + gpu_kernel(iter, f); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ComplexKernels.h b/src/ATen/native/xpu/sycl/ComplexKernels.h new file mode 100644 index 000000000..990bcd14e --- /dev/null +++ b/src/ATen/native/xpu/sycl/ComplexKernels.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void complex_kernel(TensorIterator& iter); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/GridSampler.cpp b/src/ATen/native/xpu/sycl/GridSampler.cpp index bd8e43056..746e2b035 100644 --- a/src/ATen/native/xpu/sycl/GridSampler.cpp +++ b/src/ATen/native/xpu/sycl/GridSampler.cpp @@ -11,8 +11,8 @@ #include #include -#include "GridSampler.h" -#include "UpSample.h" +#include +#include namespace at::native::xpu { diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp index f7fa13146..5b2981f12 100644 --- a/src/ATen/native/xpu/sycl/Indexing.cpp +++ b/src/ATen/native/xpu/sycl/Indexing.cpp @@ -478,6 +478,12 @@ void index_put_deterministic_kernel( const Tensor& value, bool accumulate, bool unsafe) { + TORCH_CHECK( + !indices.empty() || is_expandable_to(value.sizes(), self.sizes()), + "shape mismatch: value tensor of shape ", + value.sizes(), + " cannot be broadcast to indexing result of shape ", + self.sizes()); if (indices.size() > (size_t)self.dim()) { TORCH_CHECK_INDEX( false, @@ -563,6 +569,9 @@ void index_put_deterministic_kernel( }); if (permuted) self.copy_(src_.permute(inversePerm)); + else if (!self_contiguous) { + self.copy_(self_); + } } } diff --git a/src/ATen/native/xpu/sycl/LerpKernels.cpp b/src/ATen/native/xpu/sycl/LerpKernels.cpp new file mode 100644 index 000000000..b0f480ac3 --- /dev/null +++ b/src/ATen/native/xpu/sycl/LerpKernels.cpp @@ -0,0 +1,92 @@ +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct LerpTensorComplexFunctor { + using opmath_t = at::opmath_type; + scalar_t operator()(scalar_t self_val, scalar_t end_val, scalar_t weight_val) + const { + opmath_t self_val_f = self_val; + opmath_t end_val_f = end_val; + opmath_t weight_val_f = weight_val; + return lerp(self_val, end_val, weight_val); + } +}; + +template +struct LerpTensorFunctor { + scalar_t operator()(scalar_t self_val, scalar_t end_val, scalar_t weight_val) + const { + return lerp(self_val, end_val, weight_val); + } +}; + +template +struct LerpScalarComplexFunctor { + using opmath_t = at::opmath_type; + scalar_t operator()(scalar_t self_val, scalar_t end_val) const { + opmath_t self_val_f = self_val; + opmath_t end_val_f = end_val; + return lerp(self_val, end_val, weight_val_); + } + + LerpScalarComplexFunctor(opmath_t weight_val) : weight_val_(weight_val) {} + + private: + opmath_t weight_val_; +}; + +template +struct LerpScalarFunctor { + using opmath_t = at::opmath_type; + scalar_t operator()(scalar_t self_val, scalar_t end_val) const { + return lerp(self_val, end_val, weight_val_); + } + + LerpScalarFunctor(opmath_t weight_val) : weight_val_(weight_val) {} + + private: + opmath_t weight_val_; +}; + +void lerp_tensor_kernel(at::TensorIteratorBase& iter) { + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_xpu", [&] { + gpu_kernel(iter, LerpTensorComplexFunctor()); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "lerp_xpu", [&] { + gpu_kernel(iter, LerpTensorFunctor()); + }); + } +} + +void lerp_scalar_kernel( + at::TensorIteratorBase& iter, + const c10::Scalar& weight) { + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "lerp_xpu", [&] { + using opmath_t = at::opmath_type; + auto weight_val = weight.to(); + gpu_kernel(iter, LerpScalarComplexFunctor(weight_val)); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "lerp_xpu", [&] { + using opmath_t = at::opmath_type; + auto weight_val = weight.to(); + gpu_kernel(iter, LerpScalarFunctor(weight_val)); + }); + } +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/LerpKernels.h b/src/ATen/native/xpu/sycl/LerpKernels.h new file mode 100644 index 000000000..c455adee8 --- /dev/null +++ b/src/ATen/native/xpu/sycl/LerpKernels.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void lerp_tensor_kernel(TensorIteratorBase& iter); + +void lerp_scalar_kernel(TensorIteratorBase& iter, const c10::Scalar& weight); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ResizeKernel.cpp b/src/ATen/native/xpu/sycl/ResizeKernel.cpp new file mode 100644 index 000000000..195d526d8 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ResizeKernel.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include + +namespace at::native::xpu { + +void resize_bytes_xpu(StorageImpl* storage, size_t size_bytes) { + TORCH_CHECK( + storage->resizable(), "Trying to resize storage that is not resizable"); + auto allocator = storage->allocator(); + TORCH_CHECK( + allocator != nullptr, "Trying to resize storage without an allocator"); + + c10::Device device = storage->device(); + + if (size_bytes == 0) { + storage->set_data_ptr_noswap(at::DataPtr(nullptr, device)); + storage->set_nbytes(0); + return; + } + + c10::xpu::XPUGuard guard(device.index()); + at::DataPtr data = allocator->allocate(size_bytes); + if (storage->data_ptr()) { + at::globalContext().lazyInitXPU(); + auto q = at::xpu::getCurrentSYCLQueue(); + + q.memcpy( + data.get(), storage->data(), std::min(storage->nbytes(), size_bytes)); + } + + // Destructively overwrite data_ptr + storage->set_data_ptr_noswap(std::move(data)); + storage->set_nbytes(size_bytes); +} + +static inline void maybe_resize_storage_xpu( + TensorImpl* self, + size_t new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->numel() == 0) { + return; + } + + const Storage& storage = self->unsafe_storage(); + TORCH_CHECK(storage, "Tensor: invalid null storage"); + if (new_size_bytes > storage.nbytes()) { + resize_bytes_xpu(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +TensorImpl* resize_impl_xpu_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride, + bool device_guard = true) { + if (self->sizes() == size && (!stride || self->strides() == stride)) { + return self; + } + + // NB: We don't need to hold the device guard when calling from TH + at::xpu::OptionalXPUGuard guard; + if (device_guard) { + guard.set_index(self->storage().device().index()); + } + + const auto itemsize = self->dtype().itemsize(); + const auto storage_offset = self->storage_offset(); + size_t storage_size = 1; + if (stride) { + self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + maybe_resize_storage_xpu(self, storage_size); + + return self; +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ResizeKernel.h b/src/ATen/native/xpu/sycl/ResizeKernel.h new file mode 100644 index 000000000..5cef196ed --- /dev/null +++ b/src/ATen/native/xpu/sycl/ResizeKernel.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TensorImpl* resize_impl_xpu_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride, + bool device_guard = true); + +} diff --git a/src/ATen/native/xpu/sycl/UpSampleBicubic2dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleBicubic2dKernels.cpp index ef17e8eea..504f28d7b 100644 --- a/src/ATen/native/xpu/sycl/UpSampleBicubic2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleBicubic2dKernels.cpp @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include "UpSample.h" namespace at::native::xpu { diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp new file mode 100644 index 000000000..1ab02435a --- /dev/null +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp @@ -0,0 +1,414 @@ +#pragma clang diagnostic push +#pragma GCC diagnostic push +// Avoid SYCL compiler return-type error +#pragma clang diagnostic ignored "-Wreturn-type" +#pragma GCC diagnostic ignored "-Wreturn-type" + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native::xpu { + +template +struct UpsampleBilinear2dKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + int index = item.get_global_linear_id(); + + if (index < n_) { + const int output_x = index % output_width_; + const int output_y = index / output_width_; + + const accscalar_t h1r = area_pixel_compute_source_index( + rheight_, output_y, align_corners_, /*cubic=*/false); + const int h1 = h1r; + const int h1p = (h1 < input_height_ - 1) ? 1 : 0; + const accscalar_t h1lambda = h1r - h1; + const accscalar_t h0lambda = static_cast(1) - h1lambda; + + const accscalar_t w1r = area_pixel_compute_source_index( + rwidth_, output_x, align_corners_, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < input_width_ - 1) ? 1 : 0; + const accscalar_t w1lambda = w1r - w1; + const accscalar_t w0lambda = static_cast(1) - w1lambda; + auto odata = out_data_acc_; + for (int n = 0; n < nbatch_; n++) { + for (int c = 0; c < channels_; ++c) { + const accscalar_t val = h0lambda * + (w0lambda * in_data_acc_[n][c][h1][w1] + + w1lambda * in_data_acc_[n][c][h1][w1 + w1p]) + + h1lambda * + (w0lambda * in_data_acc_[n][c][h1 + h1p][w1] + + w1lambda * in_data_acc_[n][c][h1 + h1p][w1 + w1p]); + odata[n][c][output_y][output_x] = static_cast(val); + } + } + } + } + UpsampleBilinear2dKernelFunctor( + const int n, + const accscalar_t rheight, + const accscalar_t rwidth, + const bool align_corners, + const PackedTensorAccessor idata_acc, + PackedTensorAccessor odata_acc, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) + : n_(n), + rheight_(rheight), + rwidth_(rwidth), + align_corners_(align_corners), + in_data_acc_(idata_acc), + out_data_acc_(odata_acc), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + nbatch_(nbatch), + channels_(channels) {} + + private: + const int n_; + const accscalar_t rheight_; + const accscalar_t rwidth_; + const bool align_corners_; + const PackedTensorAccessor in_data_acc_; + PackedTensorAccessor out_data_acc_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t nbatch_; + int64_t channels_; +}; + +template +void launch_upsample_bilinear2d_kernel( + const int n, + const accscalar_t rheight, + const accscalar_t rwidth, + const bool align_corners, + const PackedTensorAccessor idata_acc, + PackedTensorAccessor odata_acc, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels) { + auto queue = getCurrentSYCLQueue(); + int64_t wg_size = syclMaxWorkGroupSize(); + int num_group = at::ceil_div(n, (int)wg_size); + + UpsampleBilinear2dKernelFunctor kfn( + n, + rheight, + rwidth, + align_corners, + idata_acc, + odata_acc, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels); + + sycl_kernel_submit( + sycl::range<1>(num_group * wg_size), sycl::range<1>(wg_size), queue, kfn); +} + +size_t idx( + const size_t nc, + const size_t height, + const size_t width, + const size_t y, + const size_t x) { + return (nc * height + y) * width + x; +} + +template +struct UpsampleBilinear2dBackwardKernelFunctor { + void operator()(sycl::nd_item<1> item) const { + for (size_t index = + item.get_local_id(0) + item.get_group(0) * item.get_local_range(0); + index < o_numel_; + index += item.get_local_range(0) * item.get_group_range(0)) { + size_t index_temp = index; + const int w2 = index_temp % output_width_; + index_temp /= output_width_; + const int h2 = index_temp % output_height_; + const size_t nc = index_temp / output_height_; + + const accscalar_t h1r = area_pixel_compute_source_index( + rheight_, h2, align_corners_, /*cubic=*/false); + const int h1 = h1r; + const int h1p = (h1 < input_height_ - 1) ? 1 : 0; + const accscalar_t h1lambda = h1r - h1; + const accscalar_t h0lambda = static_cast(1) - h1lambda; + + const accscalar_t w1r = area_pixel_compute_source_index( + rwidth_, w2, align_corners_, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < input_width_ - 1) ? 1 : 0; + const accscalar_t w1lambda = w1r - w1; + const accscalar_t w0lambda = static_cast(1) - w1lambda; + + const scalar_t d2val = out_data_[index]; + + atomicAdd( + (sycl_global_ptr< + scalar_t>)(in_data_ + idx(nc, input_height_, input_width_, h1, w1)), + static_cast(h0lambda * w0lambda * d2val)); + + atomicAdd( + (sycl_global_ptr< + scalar_t>)(in_data_ + idx(nc, input_height_, input_width_, h1, w1 + w1p)), + static_cast(h0lambda * w1lambda * d2val)); + + atomicAdd( + (sycl_global_ptr< + scalar_t>)(in_data_ + idx(nc, input_height_, input_width_, h1 + h1p, w1)), + static_cast(h1lambda * w0lambda * d2val)); + + atomicAdd( + (sycl_global_ptr< + scalar_t>)(in_data_ + idx(nc, input_height_, input_width_, h1 + h1p, w1 + w1p)), + static_cast(h1lambda * w1lambda * d2val)); + } + } + UpsampleBilinear2dBackwardKernelFunctor( + const size_t nc, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t rheight, + const accscalar_t rwidth, + const bool align_corners, + scalar_t* in_data, + const scalar_t* out_data, + const size_t o_numel, + const size_t i_numel) + : nc_(nc), + input_height_(input_height), + input_width_(input_width), + output_height_(output_height), + output_width_(output_width), + nbatch_(nbatch), + channels_(channels), + rheight_(rheight), + rwidth_(rwidth), + align_corners_(align_corners), + in_data_(in_data), + out_data_(out_data), + o_numel_(o_numel), + i_numel_(i_numel) {} + + private: + const size_t nc_; + int64_t input_height_; + int64_t input_width_; + int64_t output_height_; + int64_t output_width_; + int64_t nbatch_; + int64_t channels_; + const accscalar_t rheight_; + const accscalar_t rwidth_; + const bool align_corners_; + scalar_t* in_data_; + const scalar_t* out_data_; + const size_t o_numel_; + const size_t i_numel_; +}; + +template +void launch_upsample_bilinear2d_backward_kernel( + const size_t nc, + int64_t input_height, + int64_t input_width, + int64_t output_height, + int64_t output_width, + int64_t nbatch, + int64_t channels, + const accscalar_t rheight, + const accscalar_t rwidth, + const bool align_corners, + scalar_t* idata, + const scalar_t* odata) { + auto queue = getCurrentSYCLQueue(); + int64_t wg_size = syclMaxWorkGroupSize(); + + const size_t o_numel = nc * output_width * output_height; + const size_t i_numel = nc * input_width * input_height; + + const size_t num_kernels = nc * output_width * output_height; + int num_group = at::ceil_div((int64_t)num_kernels, (int64_t)wg_size); + + UpsampleBilinear2dBackwardKernelFunctor kfn( + nc, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + rheight, + rwidth, + align_corners, + idata, + odata, + o_numel, + i_numel); + sycl_kernel_submit( + sycl::range<1>(num_group * wg_size), sycl::range<1>(wg_size), queue, kfn); +} + +void upsample_bilinear2d_out_kernel( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + int output_height = output_size[0]; + int output_width = output_size[1]; + + int nbatch = input.size(0); + int channels = input.size(1); + int input_height = input.size(2); + int input_width = input.size(3); + + if (input.sizes() == output.sizes()) { + output.copy_(input); + return; + } + + const int num_kernels = output_height * output_width; + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "upsample_bilinear2d_xpu", + [&] { + using accscalar_t = acc_type; + auto idata_acc = input.packed_accessor64(); + auto odata_acc = output.packed_accessor64(); + + const accscalar_t rheight = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t rwidth = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + // TODO:a faster kernel for channel last + launch_upsample_bilinear2d_kernel( + num_kernels, + rheight, + rwidth, + align_corners, + idata_acc, + odata_acc, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels); + }); +} + +void upsample_bilinear2d_backward_out_kernel( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + TensorArg grad_input_arg{grad_input, "grad_input", 1}, + grad_output_arg{grad_output_, "grad_output_", 2}; + checkAllSameGPU(__func__, {grad_output_arg, grad_input_arg}); + + int output_height = output_size[0]; + int output_width = output_size[1]; + + int nbatch = input_size[0]; + int channels = input_size[1]; + int input_height = input_size[2]; + int input_width = input_size[3]; + + if (grad_input.numel() == 0) { + return; + } + + grad_input.zero_(); + + if (grad_output_.sizes() == grad_input.sizes()) { + grad_input.copy_(grad_output_); + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + grad_output_.scalar_type(), + "upsample_bilinear2d_backward_xpu", + [&] { + using accscalar_t = acc_type; + + // TODO: using PackedTensorAccessor instead of copy + Tensor grad_input_c = grad_input.is_contiguous() + ? grad_input + : at::zeros(grad_input.sizes(), grad_input.options()); + Tensor grad_output = grad_output_.contiguous(); + + scalar_t* idata = grad_input_c.data_ptr(); + scalar_t* odata = grad_output.data_ptr(); + + const accscalar_t rheight = area_pixel_compute_scale( + input_height, output_height, align_corners, scales_h); + const accscalar_t rwidth = area_pixel_compute_scale( + input_width, output_width, align_corners, scales_w); + + // TODO: a faster kernel for channel last + launch_upsample_bilinear2d_backward_kernel( + nbatch * channels, + input_height, + input_width, + output_height, + output_width, + nbatch, + channels, + rheight, + rwidth, + align_corners, + idata, + odata); + + if (!grad_input.is_contiguous()) { + grad_input.copy_(grad_input_c); + } + }); +} + +} // namespace at::native::xpu + +#pragma GCC diagnostic pop +#pragma clang diagnostic pop diff --git a/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h new file mode 100644 index 000000000..3f75f79cf --- /dev/null +++ b/src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void upsample_bilinear2d_out_kernel( + Tensor& output, + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w); + +void upsample_bilinear2d_backward_out_kernel( + Tensor& grad_input, + const Tensor& grad_output_, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest1dKernels.h b/src/ATen/native/xpu/sycl/UpSampleNearest1dKernels.h index 2b98215f3..bb6dd83ff 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest1dKernels.h +++ b/src/ATen/native/xpu/sycl/UpSampleNearest1dKernels.h @@ -1,7 +1,9 @@ +#pragma once + #include -#include -namespace at::native { -namespace xpu { +#include + +namespace at::native::xpu { void upsample_nearest1d_kernel( Tensor& output, @@ -18,6 +20,4 @@ void upsample_nearest1d_backward_kernel( std::optional scales, bool is_exact); -} // namespace xpu - -} // namespace at::native +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest2dKernels.h b/src/ATen/native/xpu/sycl/UpSampleNearest2dKernels.h index 9adf2e73a..7d11e03af 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest2dKernels.h +++ b/src/ATen/native/xpu/sycl/UpSampleNearest2dKernels.h @@ -1,7 +1,9 @@ +#pragma once + #include -#include -namespace at::native { -namespace xpu { +#include + +namespace at::native::xpu { void upsample_nearest2d_kernel( Tensor& output, @@ -20,6 +22,4 @@ void upsample_nearest2d_backward_kernel( c10::optional scales_w, bool is_exact); -} // namespace xpu - -} // namespace at::native +} // namespace at::native::xpu diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py index 7e6fd52dc..559526e49 100644 --- a/test/xpu/extended/run_test_with_skip.py +++ b/test/xpu/extended/run_test_with_skip.py @@ -91,9 +91,17 @@ # https://github.com/intel/torch-xpu-ops/issues/412 "test_compare_cpu_abs_xpu_bool", + # bilinear interpolate includes large calculation steps, accuracy reduces in half-precision + # Not in CUDA test scope too + "test_compare_cpu_nn_functional_upsample_bilinear_xpu_bfloat16", + "test_compare_cpu_nn_functional_upsample_bilinear_xpu_float16", + # CPU result is not golden reference "test_compare_cpu_nn_functional_group_norm_xpu_bfloat16", "test_compare_cpu_nn_functional_group_norm_xpu_float16", + "test_compare_cpu_nn_functional_batch_norm_xpu_bfloat16", + "test_compare_cpu__batch_norm_with_update_xpu_bfloat16", + "test_compare_cpu__batch_norm_with_update_xpu_float16", # Not implemented operators, aten::upsample_linear1d, aten::upsample_bilinear2d, # aten::upsample_trilinear3d @@ -111,6 +119,11 @@ "test_forward_ad_nn_functional_glu_xpu_float32", # For CI "test_backward_nn_functional_embedding_bag_xpu_float32", + # Precision error. + # Mismatched elements: 1 / 812 (0.1%) + # Greatest absolute difference: 0.03125 at index (610,) (up to 0.001 allowed) + # Greatest relative difference: 0.00396728515625 at index (610,) (up to 0.001 allowed) + "test_compare_cpu_hypot_xpu_bfloat16", ) diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index 73f06970e..82a674c92 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -88,9 +88,7 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_out_requires_grad_error_sparse_sampled_addmm_xpu_complex64", "test_out_requires_grad_error_sparse_sampled_addmm_xpu_float32", "test_out_nn_functional_avg_pool2d_xpu_float32", # CUDA xfail. - "test_out_warning__native_batch_norm_legit_xpu", "test_out_warning_nanmean_xpu", - "test_out_warning_native_batch_norm_xpu", "test_out_warning_nn_functional_logsigmoid_xpu", "test_python_ref__refs_div_trunc_rounding_xpu_bfloat16", "test_python_ref__refs_floor_divide_xpu_float16", @@ -217,12 +215,10 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_variant_consistency_eager_nn_functional_rrelu_xpu_float32", "test_variant_consistency_eager_to_sparse_xpu_complex64", "test_variant_consistency_eager_to_sparse_xpu_float32", - "test_compare_cpu__native_batch_norm_legit_xpu_float32", "test_compare_cpu__refs_special_zeta_xpu_float32", "test_compare_cpu_linalg_lu_factor_ex_xpu_float32", "test_compare_cpu_linalg_lu_factor_xpu_float32", "test_compare_cpu_linalg_lu_xpu_float32", - "test_compare_cpu_native_batch_norm_xpu_float32", "test_compare_cpu_special_hermite_polynomial_h_xpu_float32", "test_compare_cpu_special_zeta_xpu_float32", "test_out_cholesky_inverse_xpu_float32", @@ -269,6 +265,10 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_python_ref_executor__refs_pow_executor_aten_xpu_complex32", # Didn't align with CUDA, Unexpected success "test_compare_cpu_nn_functional_grid_sample_xpu_float32", # AssertionError: Tensor-likes are not close! "test_dtypes_nn_functional_batch_norm_without_cudnn_xpu", # AssertionError: The supported dtypes for nn.functional.batch_norm on device type xpu are incorrect! + "test_out_native_batch_norm_xpu_float32", # CUDA XFAIL, The generated sample data does not meet the requirements. + "test_out__native_batch_norm_legit_xpu_float32", # CUDA XFAIL, The generated sample data does not meet the requirements. + "test_dtypes__batch_norm_with_update_xpu", # We are same as CUDA implementation. And CUDA skips these cases. + # Jiterator is only supported on CUDA and ROCm GPUs, none are available. "_jiterator_", @@ -1326,12 +1326,6 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_InstanceNorm1d_general_xpu", "test_InstanceNorm2d_general_xpu", "test_InstanceNorm3d_general_xpu", - # AssertionError: AssertionError not raised - "test_batchnorm_simple_average_mixed_xpu_bfloat16", - "test_batchnorm_simple_average_mixed_xpu_float16", - "test_batchnorm_simple_average_xpu_float32", - "test_batchnorm_update_stats_xpu", - "test_batchnorm_simple_average_xpu_bfloat16", # AssertionError: False is not true "test_device_mask_xpu", "test_overwrite_module_params_on_conversion_cpu_device_xpu", @@ -1373,6 +1367,9 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_MultiLabelMarginLoss_no_batch_dim_mean_cuda_half", "test_MultiLabelMarginLoss_no_batch_dim_none_cuda_half", "test_MultiLabelMarginLoss_no_batch_dim_sum_cuda_half", + # align CUDA to skip, XPU implementation is not yet supporting uint8 + "test_upsamplingBiMode2d_consistency", + "test_upsamplingBiLinear2d_consistency_interp_size_bug", ) res += launch_test("test_nn_xpu.py", skip_list) @@ -2753,13 +2750,6 @@ def launch_test(test_case, skip_list=None, exe_list=None): ### Error #8 in TestBwdGradientsXPU , totally 2 , RuntimeError: DispatchStub: unsupported device typexpu "test_inplace_grad_conj_physical_xpu_complex128", "test_inplace_gradgrad_conj_physical_xpu_complex128", - # New uts added in PyTorch fail due to XPU implementation bug - # torch.autograd.gradcheck.GradcheckError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient.The tolerance for nondeterminism was 0.0. - # https://github.com/intel/torch-xpu-ops/issues/464 - "test_fn_grad__unsafe_masked_index_xpu_complex128", - "test_fn_grad__unsafe_masked_index_xpu_float64", - "test_fn_gradgrad__unsafe_masked_index_put_accumulate_xpu_complex128", - "test_fn_gradgrad__unsafe_masked_index_put_accumulate_xpu_float64", ) res += launch_test("test_ops_gradients_xpu.py", skip_list) @@ -2913,6 +2903,15 @@ def launch_test(test_case, skip_list=None, exe_list=None): "test_cuda_vitals_gpu_only_xpu", # torch.utils.swap_tensors AssertionError: RuntimeError not raised "test_swap_basic", + # Needs pr to enable deterministic implementation for interpolate op + "test_deterministic_interpolate_bilinear_xpu", + + # Precision error + # Fail in high probability in preci. + # Mismatched elements: 1 / 262144 (0.0%) + # Greatest absolute difference: 0.03125 at index (1, 227, 114) (up to 0.01 allowed) + # Greatest relative difference: 0.01495361328125 at index (1, 227, 114) (up to 0.01 allowed) + "test_index_add_correctness", ) res += launch_test("test_torch_xpu.py", skip_list) diff --git a/test/xpu/test_dataloader_xpu.py b/test/xpu/test_dataloader_xpu.py index 5058a3478..cc824665f 100644 --- a/test/xpu/test_dataloader_xpu.py +++ b/test/xpu/test_dataloader_xpu.py @@ -3,9 +3,10 @@ import os import sys import torch +import unittest from torch import multiprocessing as mp from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import parametrize, run_tests from torch.utils.data import ( DataLoader, IterDataPipe, @@ -28,7 +29,36 @@ def _set_allocator_settings(device=None): pass torch.cuda.memory._set_allocator_settings=_set_allocator_settings - from test_dataloader import * + from test_dataloader import ( + TestDatasetRandomSplit, + TestTensorDataset, + TestStackDataset, + TestConcatDataset, + TestProperExitDataset, + TestProperExitIterableDataset, + TestWorkerInfoDataset, + TestMultiEpochDataset, + TestDataLoader, + TestDataLoaderDeviceType, + TestStringDataLoader, + TestDictDataLoader, + TestDataLoaderPersistentWorkers, + TestNamedTupleDataLoader, + TestCustomPinFn, + TestWorkerQueueDataset, + TestIndividualWorkerQueue, + TestSetAffinity, + TestConvAfterFork, + TEST_CUDA_IPC, + collate_into_packed_sequence, + collate_into_packed_sequence_batch_first, + collate_wrapper, + filter_len, + row_processor, + self_module, + supported_multiprocessing_contexts, + _clone_collate, + ) def _test_multiprocessing_iterdatapipe(self, with_dill): # Testing to make sure that function from global scope (e.g. imported from library) can be serialized @@ -154,6 +184,53 @@ def custom_batch_pin_worker(self): self.assertIsInstance(sample, elem_cls) self.assertTrue(sample.is_pinned()) + @parametrize( + "context", + [ctx for ctx in supported_multiprocessing_contexts if ctx is not None], + ) + @unittest.skipIf(not TEST_CUDA_IPC, "CUDA IPC not available") + def nested_tensor_multiprocessing(self, device, context): + # The 'fork' multiprocessing context doesn't work for CUDA so skip it + if "xpu" in device and context == "fork": + # TODO: Skip this better in a better way when the test framework allows + return + + dataset = [ + torch.nested.nested_tensor([torch.randn(5)], device=device) + for _ in range(10) + ] + + pin_memory_settings = [False] + if device == "cpu" and torch.xpu.is_available(): + pin_memory_settings.append(True) + + for pin_memory in pin_memory_settings: + loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + num_workers=4, + collate_fn=_clone_collate, + pin_memory=pin_memory, + multiprocessing_context=context, + ) + + for i, batch in enumerate(loader): + self.assertEqual(batch[0], dataset[i]) + + # Error case: default collate_fn doesn't currently support batches of nested tensors. + # Following the current semantics, we'd need to stack them, which isn't possible atm. + with self.assertRaisesRegex( + RuntimeError, "not currently supported by the default collate_fn" + ): + loader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + num_workers=4, + multiprocessing_context=context, + ) + + next(iter(loader)) + TestDataLoader._test_multiprocessing_iterdatapipe = _test_multiprocessing_iterdatapipe TestDataLoader.test_sequential_pin_memory = sequential_pin_memory TestDataLoader.test_shuffle_pin_memory = shuffle_pin_memory @@ -162,6 +239,7 @@ def custom_batch_pin_worker(self): TestDictDataLoader.test_pin_memory_device = pin_memory_device TestDictDataLoader.test_pin_memory_with_only_device = pin_memory_with_only_device TestCustomPinFn.test_custom_batch_pin = custom_batch_pin + TestDataLoaderDeviceType.test_nested_tensor_multiprocessing = nested_tensor_multiprocessing instantiate_device_type_tests(TestDataLoaderDeviceType, globals(), only_for="xpu", allow_xpu=True) diff --git a/test/xpu/test_nn_xpu.py b/test/xpu/test_nn_xpu.py index 897126128..b91800473 100644 --- a/test/xpu/test_nn_xpu.py +++ b/test/xpu/test_nn_xpu.py @@ -2131,90 +2131,6 @@ def issue_24823_2(): issue_24823_2() TestNNDeviceType.test_grid_sample_large=_test_grid_sample_large -@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) -@parametrize_test("mode", ["bilinear", "bicubic"]) -@parametrize_test("antialias", [True, False]) -@parametrize_test("align_corners", [True, False]) -@parametrize_test("num_channels", [3, 5]) -@parametrize_test("output_size", [32, 600]) -@parametrize_test("check_as_unsqueezed_3d_tensor", [True, False]) -@parametrize_test("non_contig", [False, "sliced", "restrided"]) -@parametrize_test("batch_size", [1, 5]) -def upsamplingBiMode2d_consistency( - self, - device, - memory_format, - mode, - antialias, - align_corners, - num_channels, - output_size, - check_as_unsqueezed_3d_tensor, - non_contig, - batch_size, -): - # Check output value consistency between resized_input_uint8 and resized input_float - if torch.device(device).type == "xpu": - raise SkipTest("XPU implementation is not yet supporting uint8") - - torch.manual_seed(0) - - # - input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create - # [intermediate] values outside of the [0, 255] range, which need - # to be clipped in uint8 path, but not in float path. This isn't - # an issue with bilinear kernel. - input_range = (30, 220) if mode == "bicubic" else (0, 256) - input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device) - input_ui8 = input_ui8.contiguous(memory_format=memory_format) - - if non_contig == "sliced": - input_ui8 = input_ui8[:, :, 10:-10, 10:-10] - elif non_contig == "restrided": - input_ui8 = input_ui8[:, :, ::2, ::2] - - if batch_size == 1 and check_as_unsqueezed_3d_tensor: - input_ui8 = input_ui8[0, ...] - input_ui8 = input_ui8[None, ...] - - input_f32 = input_ui8.float() - - output_f32 = F.interpolate( - input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias - ).round().clip(0, 255) - output_ui8 = F.interpolate( - input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias - ) - - if non_contig is False: - self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format)) - - # FIXME if-clause shows the current behaviour which is definitely unexpected. - # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last - # See for more details: https://github.com/pytorch/pytorch/pull/100373 - if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last: - self.assertTrue(output_ui8.is_contiguous()) - self.assertTrue(output_f32.is_contiguous()) - else: - self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format)) - self.assertTrue(output_f32.is_contiguous(memory_format=memory_format)) - - if mode == "bilinear": - torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1) - else: - diff = (output_f32 - output_ui8.float()).abs() - self.assertLess(diff.max(), 15) - - threshold = 2 - percent = 3 - self.assertLess((diff > threshold).float().mean(), percent / 100) - - threshold = 5 - percent = 1 - self.assertLess((diff > threshold).float().mean(), percent / 100) - - self.assertLess(diff.mean(), 0.4) -TestNNDeviceType.test_upsamplingBiMode2d_consistency = upsamplingBiMode2d_consistency - def _test_grid_sample_half_precision(self): def helper(shape_in, shape_out, align_corners): for mode in ('bilinear', 'nearest', 'bicubic'): @@ -2255,6 +2171,75 @@ def helper(shape_in, shape_out, align_corners): # helper((32, 64, 16, 16, 16), (32, 8, 8, 8, 3), False) # grid_sampler_3d is not supported in xpu TestNNDeviceType.test_grid_sample_bfloat16_precision=_test_grid_sample_bfloat16_precision +@parametrize_test("antialias", [True, False]) +@parametrize_test("align_corners", [True, False]) +@parametrize_test("mode", ["bilinear", "bicubic"]) +@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) +def upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory_format): + # Forward AD does not support XLA because XLA tensors don't have storage + check_forward_ad = torch.device(device).type != 'xla' + + kwargs = dict(mode=mode, align_corners=align_corners, antialias=antialias) + # test float scale factor up & downsampling + for scale_factor in [0.5, 1.5, 2]: + in_t = torch.ones( + 2, 3, 8, 8, device=device, + dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_() + out_size = int(math.floor(in_t.shape[-1] * scale_factor)) + with warnings.catch_warnings(record=True) as w: + out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs) + expected_out = torch.ones(2, 3, out_size, out_size, device=device, dtype=torch.double) + self.assertEqual(expected_out, out_t) + # Assert that memory format is carried through to the output + self.assertTrue(out_t.is_contiguous(memory_format=memory_format)) + out_t.backward(torch.randn_like(out_t)) + self.assertTrue(in_t.grad.is_contiguous(memory_format=memory_format)) + + if torch.device(device).type == 'xpu': + # Bilinear backward is nondeterministic because of atomicAdd usage + nondet_tol = 1e-5 + else: + nondet_tol = 0.0 + + input = torch.randn( + 2, 3, 8, 8, device=device, + dtype=torch.double).contiguous(memory_format=memory_format).requires_grad_() + gradcheck( + lambda x: F.interpolate(x, out_size, **kwargs), + [input], + check_forward_ad=check_forward_ad, nondet_tol=nondet_tol + ) + gradgradcheck( + lambda x: F.interpolate(x, out_size, **kwargs), + [input], + check_fwd_over_rev=check_forward_ad, nondet_tol=nondet_tol + ) + + # Assert that cpu and cuda give same results + if torch.device(device).type == 'xpu': + for shapes in [ + (2, 2, 3, 4), (2, 3, 4, 5), (3, 1, 2, 2), (1, 5, 3, 2) + ]: + a_xpu = torch.randn( + *shapes, device=device, dtype=torch.double + ).contiguous(memory_format=memory_format).requires_grad_() + a_cpu = a_xpu.detach().cpu().requires_grad_() + + with warnings.catch_warnings(record=True): + out_xpu = F.interpolate(a_xpu, scale_factor=scale_factor, **kwargs) + out_cpu = F.interpolate(a_cpu, scale_factor=scale_factor, **kwargs) + + self.assertEqual(out_cpu, out_xpu.cpu()) + + g_cuda = torch.randn_like(out_xpu) + g_cpu = g_cuda.cpu() + + out_xpu.backward(g_cuda) + out_cpu.backward(g_cpu) + + self.assertEqual(a_xpu.grad, a_cpu.grad) +TestNNDeviceType.test_upsamplingBiMode2d = upsamplingBiMode2d + instantiate_device_type_tests( TestNNDeviceType, globals(), only_for="xpu", allow_xpu=True ) diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 09303c62d..fbb2dc73c 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -27,6 +27,7 @@ "view_as_real", "view_as_complex", "view", + "trace", "resize_", "resize_as_", "add", @@ -81,6 +82,8 @@ "nn.functional.threshold", "nn.functional.silu", "nn.functional.hardsigmoid", + "nn.functional.softplus", + "nn.functional.softshrink", "nonzero", "normal", "pow", @@ -100,6 +103,7 @@ "var", "var_mean", "tanh", + "hypot", "unfold", "uniform", "view", @@ -127,6 +131,7 @@ "nn.functional.unfold", "nn.functional.pad", "nn.functional.interpolate", + "nn.functional.upsample_bilinear", "nn.functional.upsample_nearest", # "nn.functional.nll_loss", # Lack of XPU implementation of aten::nll_loss2d_forward. Will retrieve the case, only if the op is implemented. "nn.functional.mse_loss", @@ -140,8 +145,13 @@ "addr", "cdist", "nn.functional.group_norm", + "nn.functional.batch_norm", + "native_batch_norm", + "_native_batch_norm_legit", + "_batch_norm_with_update", "bincount", "renorm", + "lerp", ] diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 90fc8287d..5b026c544 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -10,6 +10,8 @@ supported: - add_.Scalar - add.Scalar_out - _adaptive_avg_pool2d_backward + - adaptive_avg_pool2d.out + - _adaptive_avg_pool2d - cumsum - cumsum.out - cumsum_ @@ -65,6 +67,12 @@ supported: - le.Tensor - le.Tensor_out - le_.Tensor + - lerp.Tensor + - lerp.Tensor_out + - lerp_.Tensor + - lerp.Scalar + - lerp.Scalar_out + - lerp_.Scalar - gt.Scalar - gt.Scalar_out - gt_.Scalar @@ -91,6 +99,9 @@ supported: - gcd - gcd.out - gcd_ + - hypot + - hypot.out + - hypot_ - relu - relu_ - relu.out @@ -99,6 +110,14 @@ supported: - threshold.out - threshold_backward - threshold_backward.grad_input + - softplus + - softplus.out + - softplus_backward + - softplus_backward.grad_input + - softshrink + - softshrink.out + - softshrink_backward + - softshrink_backward.grad_input - gelu - gelu_ - gelu.out @@ -171,6 +190,8 @@ supported: - exp_ - empty.memory_format - empty_strided + - _efficientzerotensor + - complex.out - clone - fill_.Scalar - fill_.Tensor @@ -385,6 +406,26 @@ supported: - nll_loss_forward - nll_loss_backward.grad_input - nll_loss_backward + - batch_norm_stats + - batch_norm_elemt + - batch_norm_elemt.out + - batch_norm_backward_reduce + - batch_norm_backward_elemt + - batch_norm_update_stats + - native_batch_norm + - native_batch_norm.out + - native_batch_norm_backward + - _native_batch_norm_legit + - _native_batch_norm_legit.out + - _native_batch_norm_legit.no_stats + - _native_batch_norm_legit.no_stats_out + - _batch_norm_with_update + - _batch_norm_with_update.out + - batch_norm_backward + - upsample_bilinear2d + - upsample_bilinear2d.out + - upsample_bilinear2d_backward + - upsample_bilinear2d_backward.grad_input - _upsample_nearest_exact1d - _upsample_nearest_exact1d.out - upsample_nearest1d @@ -412,6 +453,7 @@ supported: - _cdist_forward - _pin_memory - is_pinned + - trace - reflection_pad2d - reflection_pad2d.out - reflection_pad2d_backward