From e210c5cee1922c820643967d94f39cc30325ff23 Mon Sep 17 00:00:00 2001 From: chunhuanMeng <105194461+chunhuanMeng@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:14:51 +0800 Subject: [PATCH 1/5] Enable aten::smooth_l1_loss forward/backward (#621) --- src/ATen/native/xpu/Loss.cpp | 58 +++++++++++++++++++ src/ATen/native/xpu/XPUFallback.template | 2 - .../native/xpu/sycl/BinaryMiscOpsKernels.cpp | 26 +++++++++ .../native/xpu/sycl/BinaryMiscOpsKernels.h | 2 + .../native/xpu/sycl/PointwiseOpsKernels.cpp | 34 +++++++++++ .../native/xpu/sycl/PointwiseOpsKernels.h | 2 + test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 3 + 8 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/Loss.cpp b/src/ATen/native/xpu/Loss.cpp index f09f68b8a..050ff07b9 100644 --- a/src/ATen/native/xpu/Loss.cpp +++ b/src/ATen/native/xpu/Loss.cpp @@ -80,6 +80,64 @@ Tensor& XPUNativeFunctions::mse_loss_backward_out( return grad_input; } + +Tensor& XPUNativeFunctions::smooth_l1_loss_out( + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta, + Tensor& result) { + if (reduction != Reduction::None) { + TORCH_INTERNAL_ASSERT( + reduction == Reduction::Mean || reduction == Reduction::Sum); + result.resize_({}); + Tensor loss; + auto iter = TensorIterator::borrowing_binary_op(loss, input, target); + native::xpu::smooth_l1_kernel(iter, beta); + if (reduction == Reduction::Mean) { + at::mean_out(const_cast(result), iter.output(), IntArrayRef{}); + } else { + at::sum_out(const_cast(result), iter.output(), IntArrayRef{}); + } + } else { + auto iter = TensorIterator::borrowing_binary_op(result, input, target); + native::xpu::smooth_l1_kernel(iter, beta); + } + return result; +} + +Tensor XPUNativeFunctions::smooth_l1_loss( + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta) { + Tensor result = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + result = XPUNativeFunctions::smooth_l1_loss_out( + input, target, reduction, beta, result); + return result; +} + +Tensor& XPUNativeFunctions::smooth_l1_loss_backward_out( + const Tensor& grad_output, + const Tensor& input, + const Tensor& target, + int64_t reduction, + double beta, + Tensor& grad_input) { + auto norm = reduction == Reduction::Mean ? 1. 
/ input.numel() : 1.; + auto iter = at::TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(input) + .add_const_input(target) + .add_const_input(grad_output) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true) + .build(); + native::xpu::smooth_l1_backward_kernel(iter, norm, beta); + return grad_input; +} + Tensor XPUNativeFunctions::binary_cross_entropy( const Tensor& self, const Tensor& target, diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 3f2064653..93321f23d 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -257,8 +257,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "signbit.out", "sign.out", "sinc.out", - "smooth_l1_loss_backward.grad_input", - "smooth_l1_loss.out", "special_airy_ai.out", "special_bessel_j0.out", "special_bessel_j1.out", diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp index 00c5398af..5ac71c163 100644 --- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp @@ -23,6 +23,32 @@ void mse_kernel(TensorIteratorBase& iter) { [&]() { gpu_kernel(iter, MSEFunctor()); }); } +template +struct SmoothL1Functor { + scalar_t operator()(scalar_t input, scalar_t target) const { + auto z = std::abs(input - target); + return z < beta_val ? scalar_t(0.5) * z * z / beta_val + : z - scalar_t(0.5) * beta_val; + } + SmoothL1Functor(scalar_t beta_val) : beta_val(beta_val) {} + + private: + scalar_t beta_val; +}; + +void smooth_l1_kernel(TensorIteratorBase& iter, double beta) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "smooth_l1_xpu", + [&iter, beta]() { + scalar_t beta_val(beta); + SmoothL1Functor f(beta_val); + gpu_kernel(iter, f); + }); +} + template struct HuberFunctor { scalar_t operator()(scalar_t a, scalar_t b) const { diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h index 94cfb7c90..17672ec29 100644 --- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h +++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h @@ -6,6 +6,8 @@ namespace at::native::xpu { void mse_kernel(TensorIteratorBase& iter); +void smooth_l1_kernel(TensorIteratorBase& iter, double beta); + void huber_kernel(TensorIterator& iter, double delta); } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp index 5dc06e25a..822a83e99 100644 --- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.cpp @@ -125,6 +125,40 @@ void mse_backward_kernel(TensorIterator& iter, const Scalar& value) { }); } +template +struct SmoothL1BackwardFunctor { + scalar_t operator()(scalar_t input, scalar_t target, scalar_t grad_output) + const { + const auto x = input - target; + if (x < -beta_val) + return -norm_val * grad_output; + else if (x > beta_val) + return norm_val * grad_output; + else + return norm_val * x * grad_output / beta_val; + } + SmoothL1BackwardFunctor(scalar_t norm_val, scalar_t beta_val) + : norm_val(norm_val), beta_val(beta_val) {} + + private: + scalar_t norm_val; + scalar_t beta_val; +}; + +void smooth_l1_backward_kernel(TensorIterator& iter, Scalar norm, double beta) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, + kBFloat16, + iter.dtype(), + "smooth_l1_backward_xpu", + 
[&iter, &norm, beta] { + auto norm_val = norm.to(); + scalar_t beta_val(beta); + SmoothL1BackwardFunctor f(norm_val, beta_val); + gpu_kernel(iter, f); + }); +} + template struct HuberBackwardFunctor { scalar_t operator()(scalar_t input, scalar_t target, scalar_t grad_output) diff --git a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h index 586a64f3c..613c3cca6 100644 --- a/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h +++ b/src/ATen/native/xpu/sycl/PointwiseOpsKernels.h @@ -10,6 +10,8 @@ void addcdiv_kernel(TensorIterator& iter, Scalar value); void mse_backward_kernel(TensorIterator& iter, const Scalar& value); +void smooth_l1_backward_kernel(TensorIterator& iter, Scalar norm, double beta); + void huber_backward_kernel( TensorIterator& iter, const Scalar& norm, diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 2a36ddfb0..c281747f2 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -179,6 +179,7 @@ "nn.functional.upsample_bilinear", "nn.functional.upsample_nearest", "nn.functional.nll_loss", + "nn.functional.smooth_l1_loss", "nn.functional.mse_loss", "nn.functional.binary_cross_entropy", "nn.functional.huber_loss", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 77d77d4f0..fd087c7bc 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -308,6 +308,9 @@ supported: - bitwise_and.Tensor_out - bitwise_or.Tensor_out - bitwise_xor.Tensor_out + - smooth_l1_loss + - smooth_l1_loss.out + - smooth_l1_loss_backward.grad_input - bitwise_not.out - where.self_out - where.self From 36dfe230dea6a737fe260b072276cbcca3ca3f9a Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Tue, 30 Jul 2024 08:40:20 +0800 Subject: [PATCH 2/5] Add aten::polar and its variants (#606) Co-authored-by: yucai Co-authored-by: Feng Yuan --- src/ATen/native/xpu/TensorFactories.cpp | 15 +++++++++++++++ src/ATen/native/xpu/XPUFallback.template | 1 - src/ATen/native/xpu/sycl/ComplexKernels.cpp | 14 ++++++++++++++ src/ATen/native/xpu/sycl/ComplexKernels.h | 2 ++ test/xpu/extended/run_test_with_skip.py | 4 ++++ test/xpu/run_test_with_skip.py | 12 ++++++++++-- test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 1 + 8 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/TensorFactories.cpp b/src/ATen/native/xpu/TensorFactories.cpp index 110590958..44da487f7 100644 --- a/src/ATen/native/xpu/TensorFactories.cpp +++ b/src/ATen/native/xpu/TensorFactories.cpp @@ -151,6 +151,21 @@ Tensor& XPUNativeFunctions::complex_out( return result; } +Tensor& XPUNativeFunctions::polar_out( + const Tensor& abs, + const Tensor& angle, + Tensor& result) { + complex_check_dtype(result, abs, angle); + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(abs) + .add_const_input(angle) + .check_all_same_dtype(false) + .build(); + native::xpu::polar_kernel(iter); + return result; +} + Tensor& XPUNativeFunctions::randperm_out( int64_t n, c10::optional generator, diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 93321f23d..4a4c96828 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -240,7 +240,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "ormqr", "_pdist_backward", "_pdist_forward", - "polar.out", "_prelu_kernel", "_prelu_kernel_backward", "prod", diff --git a/src/ATen/native/xpu/sycl/ComplexKernels.cpp 
b/src/ATen/native/xpu/sycl/ComplexKernels.cpp index 56b25d0ef..87504bd5e 100644 --- a/src/ATen/native/xpu/sycl/ComplexKernels.cpp +++ b/src/ATen/native/xpu/sycl/ComplexKernels.cpp @@ -21,4 +21,18 @@ void complex_kernel(TensorIterator& iter) { }); } +template +struct PolarFunctor { + c10::complex operator()(scalar_t a, scalar_t b) const { + return c10::complex(a * std::cos(b), a * std::sin(b)); + } +}; + +void polar_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(0), "polar_xpu", [&]() { + PolarFunctor f; + gpu_kernel(iter, f); + }); +} + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ComplexKernels.h b/src/ATen/native/xpu/sycl/ComplexKernels.h index 990bcd14e..d51556b4f 100644 --- a/src/ATen/native/xpu/sycl/ComplexKernels.h +++ b/src/ATen/native/xpu/sycl/ComplexKernels.h @@ -6,4 +6,6 @@ namespace at::native::xpu { void complex_kernel(TensorIterator& iter); +void polar_kernel(TensorIterator& iter); + } // namespace at::native::xpu diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py index 6f8fe8d3a..a75d2e675 100644 --- a/test/xpu/extended/run_test_with_skip.py +++ b/test/xpu/extended/run_test_with_skip.py @@ -154,6 +154,10 @@ # Greatest relative difference: 0.00396728515625 at index (610,) (up to 0.001 allowed) "test_compare_cpu_hypot_xpu_bfloat16", + # RuntimeError: Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16. + # Polar's backward is calculated using complex(), which does not support bfloat16. CUDA fails with same error. + "test_compare_cpu_polar_xpu_bfloat16", + # Regressions due to PyTorch uplift (Numeric difference in float and bfloat) # https://github.com/intel/torch-xpu-ops/issues/549 # Example fail log diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index 719af3ca4..7d051607e 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -782,6 +782,10 @@ def launch_test(test_case, skip_list=None, exe_list=None): # torch.complex32 - "sinh_cpu" not implemented for 'ComplexHalf' "test_dtypes_cosh_xpu", + # RuntimeError: Expected both inputs to be Half, Float or Double tensors but got BFloat16 and BFloat16. + # Polar's backward is calculated using complex(), which does not support bfloat16. CUDA fails with same error. 
+ "test_dtypes_polar_xpu", + # implemented aten::histogram to align MPS operators coverage, CUDA doesn't support # but test_dtypes infrastructure leverage CUDA supported datatypes "test_dtypes_histogram_xpu", @@ -3016,8 +3020,12 @@ def launch_test(test_case, skip_list=None, exe_list=None): res += launch_test("nn/test_load_state_dict_xpu.py") # test_module_hooks - -res += launch_test("nn/test_module_hooks_xpu.py") +skip_list = ( + # TypeError: TestStateDictHooks.test_register_state_dict_post_hook() missing 1 required positional argument: 'private' + # https://github.com/intel/torch-xpu-ops/issues/658 + "test_register_state_dict_post_hook", +) +res += launch_test("nn/test_module_hooks_xpu.py", skip_list) # test_parametrization diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index c281747f2..823988488 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -208,6 +208,7 @@ "unique", "multinomial", "lerp", + "polar", "frac", "aminmax", "argmin", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index fd087c7bc..9d453d215 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -268,6 +268,7 @@ supported: - eye.m_out - _efficientzerotensor - complex.out + - polar.out - clone - fill_.Scalar - fill_.Tensor From 82268376d9f215f5ca0988264608ead45ae028f9 Mon Sep 17 00:00:00 2001 From: mengfei25 Date: Tue, 30 Jul 2024 13:19:00 +0800 Subject: [PATCH 3/5] Enable weekly test (#637) 1. enable weekly test contains 3 suites full e2e and full ut 2. always() to not cancelled() --- .github/workflows/nightly_ondemand.yml | 59 +++++++++++++++++++++----- .github/workflows/pull.yml | 4 +- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/.github/workflows/nightly_ondemand.yml b/.github/workflows/nightly_ondemand.yml index 1a663661e..e5be18fb5 100644 --- a/.github/workflows/nightly_ondemand.yml +++ b/.github/workflows/nightly_ondemand.yml @@ -2,8 +2,10 @@ name: Nightly-OnDemand Tests on: schedule: - # GMT+8 21:00 every day - - cron: '0 13 * * *' + # GMT+8 21:00 every workday + - cron: '0 13 * * 0-4' + # GMT+8 0:00 Saturday + - cron: '0 16 * * 5' workflow_dispatch: inputs: pytorch: @@ -78,7 +80,7 @@ jobs: runs-on: pvc_e2e # Don't run on forked repos if: github.repository_owner == 'intel' - timeout-minutes: 900 + timeout-minutes: 3600 env: pytorch: ${{ github.event_name == 'schedule' && 'main' || inputs.pytorch }} keep_torch_xpu_ops: ${{ github.event_name == 'schedule' && 'false' || inputs.keep_torch_xpu_ops }} @@ -174,8 +176,10 @@ jobs: echo "$GITHUB_ENV" rm -rf ../pytorch/inductor_log rm -rf /tmp/torchinductor_* + + # Nihglty launch - name: Nightly Huggingface FP32/BF16/FP16 Inference & Training Accuracy Test - if: github.event_name == 'schedule' + if: github.event_name == 'schedule' && github.event.schedule == '0 13 * * 0-4' uses: ./.github/actions/inductor-xpu-e2e-test with: suite: huggingface @@ -185,7 +189,7 @@ jobs: scenario: accuracy hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - name: Nightly Torchbench BF16 Training Accuracy Test - if: github.event_name == 'schedule' + if: github.event_name == 'schedule' && github.event.schedule == '0 13 * * 0-4' uses: ./.github/actions/inductor-xpu-e2e-test with: suite: torchbench @@ -195,7 +199,7 @@ jobs: env_prepare: true hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - name: Nightly Timm_models FP16 Training Accuracy Test - if: github.event_name == 'schedule' + if: github.event_name == 'schedule' && github.event.schedule == '0 13 * * 0-4' uses: ./.github/actions/inductor-xpu-e2e-test 
with: suite: timm_models @@ -204,6 +208,38 @@ jobs: scenario: accuracy env_prepare: true hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + # Weekly launch + - name: Weekly Huggingface Full Test + if: github.event_name == 'schedule' && github.event.schedule == '0 16 * * 5' + uses: ./.github/actions/inductor-xpu-e2e-test + with: + suite: huggingface + env_prepare: true + dt: float32,bfloat16,float16,amp_bf16,amp_fp16 + mode: inference,training + scenario: accuracy,performance + hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + - name: Weekly Torchbench Full Test + if: github.event_name == 'schedule' && github.event.schedule == '0 16 * * 5' + uses: ./.github/actions/inductor-xpu-e2e-test + with: + suite: torchbench + env_prepare: true + dt: float32,bfloat16,float16,amp_bf16,amp_fp16 + mode: inference,training + scenario: accuracy,performance + hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + - name: Weekly Timm_models Full Test + if: github.event_name == 'schedule' && github.event.schedule == '0 16 * * 5' + uses: ./.github/actions/inductor-xpu-e2e-test + with: + suite: timm_models + env_prepare: true + dt: float32,bfloat16,float16,amp_bf16,amp_fp16 + mode: inference,training + scenario: accuracy,performance + hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + # On-demand launch - name: OnDemand Test (${{ inputs.suite }} ${{ inputs.dt }} ${{ inputs.mode }} ${{ inputs.scenario }}) if: github.event_name != 'schedule' uses: ./.github/actions/inductor-xpu-e2e-test @@ -216,7 +252,7 @@ jobs: hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - name: Summarize archieve files id: summary - if: always() + if: ${{ ! cancelled() }} run: | rm -rf ${{ github.workspace }}/upload_files cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files @@ -237,14 +273,14 @@ jobs: exit 1 fi - name: Upload Inductor XPU E2E Data - if: always() + if: ${{ ! cancelled() }} uses: actions/upload-artifact@v4 with: name: Inductor-XPU-E2E-Data-${{ github.event.pull_request.number || github.sha }} path: ${{ github.workspace }}/upload_files - + Tests-Failure-And-Report: - if: always() + if: ${{ ! cancelled() }} runs-on: pvc_e2e permissions: issues: write @@ -288,6 +324,9 @@ jobs: test_type="On-demand" test_issue_id=426 cc_comment="CC @${GITHUB_TRIGGERING_ACTOR}" + elif [ "${{ github.event.schedule }}" == "0 16 * * 5" ];then + test_type="Weekly" + test_issue_id=432 else test_type="Nightly" test_issue_id=432 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 5ee55fdcf..350c88f91 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -127,7 +127,7 @@ jobs: env_prepare: true hf_token: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - name: Summarize archieve files - if: always() + if: ${{ ! cancelled() }} run: | rm -rf ${{ github.workspace }}/upload_files cp -r ${{ github.workspace }}/../pytorch/inductor_log ${{ github.workspace }}/upload_files @@ -137,7 +137,7 @@ jobs: exit 1 fi - name: Upload Inductor XPU E2E Data - if: always() + if: ${{ ! cancelled() }} uses: actions/upload-artifact@v4 with: name: Inductor-XPU-E2E-Data-${{ github.event.pull_request.number || github.sha }} From 67116b3c359c619f0fdb9b0266f4fbd6607cc956 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Tue, 30 Jul 2024 13:55:12 +0800 Subject: [PATCH 4/5] Add aten::masked_scatter_ (#652) Add aten::masked_scatter_. 
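For context on the operator wired up in this patch, a minimal usage sketch (not part of the diff itself): it assumes a PyTorch build where the torch.xpu backend is present and simply falls back to CPU otherwise; the tensor shapes and values are purely illustrative.

import torch

# Illustrative only: exercise the in-place masked_scatter_ path this patch enables on XPU.
# Assumes torch.xpu is available in this build; otherwise the same code runs on CPU.
device = "xpu" if torch.xpu.is_available() else "cpu"

dst = torch.zeros(2, 4, device=device)
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]], device=device)
# source must supply at least as many elements as there are True entries in mask,
# and must have the same dtype as dst.
source = torch.arange(1.0, 5.0, device=device)

dst.masked_scatter_(mask, source)  # masked positions receive source values in row-major order
print(dst)  # expected: tensor([[1., 0., 2., 0.], [0., 3., 0., 4.]])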
--------- Co-authored-by: Feng Yuan --- src/ATen/native/xpu/Indexing.cpp | 30 +++++++ src/ATen/native/xpu/XPUFallback.template | 1 - src/ATen/native/xpu/sycl/Indexing.cpp | 90 +++++++++++++++++++ src/ATen/native/xpu/sycl/IndexingKernels.h | 6 ++ src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h | 13 ++- test/xpu/test_torch_xpu.py | 4 +- test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 1 + 8 files changed, 136 insertions(+), 10 deletions(-) diff --git a/src/ATen/native/xpu/Indexing.cpp b/src/ATen/native/xpu/Indexing.cpp index 5db9a7238..d4d5598e6 100644 --- a/src/ATen/native/xpu/Indexing.cpp +++ b/src/ATen/native/xpu/Indexing.cpp @@ -45,6 +45,36 @@ Tensor XPUNativeFunctions::index_select( return index_select_out(self, dim, index, out); } +Tensor& XPUNativeFunctions::masked_scatter_( + Tensor& self, + const Tensor& mask, + const Tensor& source) { + at::assert_no_internal_overlap(self); + TORCH_CHECK( + self.scalar_type() == source.scalar_type(), + "masked_scatter_: expected self and source to have same dtypes but got ", + self.scalar_type(), + " and ", + source.scalar_type()); + TORCH_CHECK( + mask.dtype() == ScalarType::Bool, + "masked_scatter_ only supports boolean masks, " + "but got mask with dtype ", + mask.dtype()); + + c10::MaybeOwned b_mask = + expand_inplace(self, mask, "masked_scatter_"); + + if (self.numel() == 0) { + return self; + } + + auto maskPrefixSum = at::empty(self.sizes(), mask.options().dtype(kLong)); + native::xpu::masked_scatter_kernel(self, *b_mask, maskPrefixSum, source); + + return self; +} + static Tensor& masked_select_out_impl( Tensor& result, const Tensor& self, diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 4a4c96828..23b79d6dd 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -224,7 +224,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "log_normal_", "logspace.out", "lu_unpack.out", - "masked_scatter_", "max_pool3d_with_indices", "max_pool3d_with_indices_backward", "max_unpool2d", diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp index 312f4ce12..fcee372d1 100644 --- a/src/ATen/native/xpu/sycl/Indexing.cpp +++ b/src/ATen/native/xpu/sycl/Indexing.cpp @@ -597,6 +597,7 @@ void index_put_deterministic_kernel( if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) { auto expanded_size = at::DimVector(expandedValue.sizes()); + auto size1 = expandedValue.sizes(); auto size2 = linearIndex.sizes(); if (are_expandable(size1, size2)) { @@ -667,6 +668,95 @@ void index_put_deterministic_kernel( } } +template +struct MaskedScatterElementwiseFunctor { + scalar_t operator()( + const scalar_t a, + const bool mask, + const int64_t maskPrefixSum) const { + if (mask) { + return source_ptr_[maskPrefixSum]; + } + return a; + } + MaskedScatterElementwiseFunctor(const scalar_t* source_ptr) + : source_ptr_(source_ptr) {} + + private: + const scalar_t* source_ptr_; +}; + +struct MaskedScatterSizeCheckFunctor { + void operator()(sycl::nd_item<1> item) const { + const auto totalElements = *mask_exclusive_sum_ + *mask_; + SYCL_KERNEL_ASSERT(totalElements <= srcSize_); + } + MaskedScatterSizeCheckFunctor( + const int64_t* const mask_exclusive_sum, + const bool* const mask, + const int64_t srcSize) + : mask_exclusive_sum_(mask_exclusive_sum), + mask_(mask), + srcSize_(srcSize) {} + + private: + const int64_t* const mask_exclusive_sum_; + const bool* const mask_; + const int64_t srcSize_; +}; + +void masked_scatter_kernel( + const 
TensorBase& self, + const TensorBase& mask, + const TensorBase& maskPrefixSum, + const TensorBase& source) { + const auto srcSize = source.numel(); + const auto mask_cont = mask.contiguous(); + const auto mask_numel = mask.numel(); + + // Use a prefix sum to determine the output locations of the masked elements + auto maskPrefixSum_data = maskPrefixSum.mutable_data_ptr(); + auto mask_data = mask_cont.const_data_ptr(); + + pstl::exclusive_scan( + mask_data, + mask_data + mask_numel, + maskPrefixSum_data, + static_cast(0)); + + // Asynchronously check that the number of `1` elements present in the mask + // must be <= the number of elements available in `src`. + auto caller = MaskedScatterSizeCheckFunctor( + &maskPrefixSum_data[mask_numel - 1], &mask_data[mask_numel - 1], srcSize); + sycl_kernel_submit((size_t)1, (size_t)1, getCurrentSYCLQueue(), caller); + + // We are getting elements from `src` based on an offset from + // `maskPrefixSum`, so that should be made contiguous too + auto source_contig = source.contiguous(); + + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) + .check_all_same_dtype(false) + .resize_outputs(false) + .add_output(self) + .add_input(self) + .add_const_input(mask_cont) + .add_input(maskPrefixSum) + .build(); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + ScalarType::Bool, + ScalarType::BFloat16, + ScalarType::Half, + self.scalar_type(), + "masked_scatter_", + [&]() { + auto source_ptr = source_contig.const_data_ptr(); + gpu_kernel(iter, MaskedScatterElementwiseFunctor(source_ptr)); + }); +} + } // namespace at::native::xpu + #pragma GCC diagnostic pop #pragma clang diagnostic pop diff --git a/src/ATen/native/xpu/sycl/IndexingKernels.h b/src/ATen/native/xpu/sycl/IndexingKernels.h index cde537e73..8f32f49f9 100644 --- a/src/ATen/native/xpu/sycl/IndexingKernels.h +++ b/src/ATen/native/xpu/sycl/IndexingKernels.h @@ -47,4 +47,10 @@ void index_put_deterministic_kernel( bool accumulate, bool unsafe); +void masked_scatter_kernel( + const TensorBase& self, + const TensorBase& mask, + const TensorBase& maskPrefixSum, + const TensorBase& source); + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h b/src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h index 66c316e25..efc9f164e 100644 --- a/src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h +++ b/src/ATen/native/xpu/sycl/pstl/PSTLFunctions.h @@ -1,11 +1,10 @@ #pragma once #include -#include - #include #include #include +#include #include #include #include @@ -23,10 +22,10 @@ struct KSScanKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { // initialize local_input auto cur_init = init_; if (scan_type == 1) { - local_scan_[local_id] = first_[local_id]; + local_scan_[local_id] = c10::load(&first_[local_id]); } else { if (local_id > 0) - local_scan_[local_id] = first_[local_id - 1]; + local_scan_[local_id] = c10::load(&first_[local_id - 1]); else local_scan_[local_id] = 0; } @@ -72,17 +71,17 @@ struct KSScanWithCarrierKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { auto cur_init = (group_id == 0 ? 
init_ : 0); if (global_id < N_) { if (scan_type == 1) { - local_scan_[local_id] = first_[global_id]; + local_scan_[local_id] = c10::load(&first_[global_id]); } else { if (local_id > 0) - local_scan_[local_id] = first_[global_id - 1]; + local_scan_[local_id] = c10::load(&first_[global_id - 1]); else local_scan_[local_id] = 0; } if (local_id == 0) local_scan_[local_id] += cur_init; if (local_id == wgroup_size_ - 1) { - carry_ptr_[group_id] = first_[global_id]; + carry_ptr_[group_id] = c10::load(&first_[global_id]); } } item_id.barrier(sycl_local_fence); diff --git a/test/xpu/test_torch_xpu.py b/test/xpu/test_torch_xpu.py index b82a8ec67..80fb3c8b0 100644 --- a/test/xpu/test_torch_xpu.py +++ b/test/xpu/test_torch_xpu.py @@ -3995,11 +3995,11 @@ def test_masked_scatter(self, device, dtype): dest_ones.masked_scatter_(mask, src_ones) self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0) - # Bound checking in CUDA is done inside a kernel + # Bound checking in GPU is done inside a kernel # in order to avoid synchronization, but this means # we can not clear the failures. So there is no way # to test it then recover. - if self.device_type != 'cuda' or self.device_type != 'xpu': + if self.device_type != 'cuda' and self.device_type != 'xpu': # make src smaller. this should fail src = torch.zeros(num_copy - 1, dtype=dt, device=device) with self.assertRaises(RuntimeError): diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 823988488..635b1f1b6 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -80,6 +80,7 @@ "index_fill", "index_put", "index_select", + "masked_scatter", "masked_select", "isin", "isnan", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 9d453d215..5c354e0b0 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -105,6 +105,7 @@ supported: - isnan.out - masked_fill_.Tensor - masked_fill_.Scalar + - masked_scatter_ - index_add.out - index_add_ - index_add From 8a821bf6aaa72f733ba65f6b2cbef40983bd71a9 Mon Sep 17 00:00:00 2001 From: "Huaiyu, Zheng" Date: Tue, 30 Jul 2024 14:37:25 +0800 Subject: [PATCH 5/5] Add aten::_weight_norm_interface and aten::_weight_norm_interface_backward (#607) add aten::_weight_norm_interface and aten::_weight_norm_interface_backward --- src/ATen/native/xpu/WeightNorm.cpp | 27 + .../xpu/sycl/AdaptiveMaxPooling2dKernels.cpp | 8 +- src/ATen/native/xpu/sycl/BatchKernel.h | 237 ++-- src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp | 13 +- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 5 +- src/ATen/native/xpu/sycl/Indexing.h | 14 +- .../native/xpu/sycl/WeightNormKernels.cpp | 1027 +++++++++++++++++ src/ATen/native/xpu/sycl/WeightNormKernels.h | 16 + yaml/xpu_functions.yaml | 2 + 9 files changed, 1234 insertions(+), 115 deletions(-) create mode 100644 src/ATen/native/xpu/WeightNorm.cpp create mode 100644 src/ATen/native/xpu/sycl/WeightNormKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/WeightNormKernels.h diff --git a/src/ATen/native/xpu/WeightNorm.cpp b/src/ATen/native/xpu/WeightNorm.cpp new file mode 100644 index 000000000..7fec9ecfe --- /dev/null +++ b/src/ATen/native/xpu/WeightNorm.cpp @@ -0,0 +1,27 @@ +#include +#include +namespace at { +std::tuple XPUNativeFunctions::_weight_norm_interface( + const Tensor& v, + const Tensor& g, + int64_t dim) { + return native::xpu::weight_norm_kernel(v, g, dim); +} + +std::tuple XPUNativeFunctions::_weight_norm_interface_backward( + const Tensor& grad_w, + const Tensor& saved_v, + const Tensor& saved_g, + const Tensor& saved_norms, + 
int64_t dim) { + TORCH_CHECK(saved_v.is_contiguous(), "saved_v must be contiguous"); + TORCH_CHECK(saved_g.is_contiguous(), "saved_g must be contiguous"); + TORCH_CHECK(saved_norms.is_contiguous(), "saved_norms must be contiguous"); + TORCH_CHECK( + dim == 0 || dim == saved_v.dim() - 1, + "fused kernels can only be applied for first or last dim") + + return native::xpu::weight_norm_backward_kernel( + grad_w, saved_v, saved_g, saved_norms, dim); +} +} // namespace at \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/AdaptiveMaxPooling2dKernels.cpp b/src/ATen/native/xpu/sycl/AdaptiveMaxPooling2dKernels.cpp index fe50d1a6c..9da231b0d 100644 --- a/src/ATen/native/xpu/sycl/AdaptiveMaxPooling2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AdaptiveMaxPooling2dKernels.cpp @@ -124,8 +124,8 @@ void launch_adaptive_max_pool2d_kernel( using KernelClass = AdaptiveMaxPool2dKernelFunctor; int64_t output_size = batch * plane * osizeH * osizeW; - BatchKernelConfig cfg = { - 1, output_size, 1, 1, true, BatchKernelConfig::Policy::pAdaptive}; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, output_size, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.build(); @@ -301,8 +301,8 @@ void launch_adaptive_max_pool2d_backward_kernel( int64_t sizeP) { using KernelClass = AdaptiveMaxPool2dBackwardKernelFunctor; - BatchKernelConfig cfg = { - 1, osize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive}; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, osize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.build(); diff --git a/src/ATen/native/xpu/sycl/BatchKernel.h b/src/ATen/native/xpu/sycl/BatchKernel.h index cff967a76..bc0da3bb5 100644 --- a/src/ATen/native/xpu/sycl/BatchKernel.h +++ b/src/ATen/native/xpu/sycl/BatchKernel.h @@ -25,6 +25,146 @@ class BatchKernelConfig { } public: + template + static BatchKernelConfig make_config( + int64_t batch, + int64_t problem, + int64_t stride, + int64_t problem_batch, + bool problem_along_x, + Policy policy = Policy::pSegment, + int64_t prefer_wg_size = 0) { + BatchKernelConfig cfg = { + batch, + problem, + stride, + problem_batch, + problem_along_x, + policy, + prefer_wg_size}; + cfg.template build(); + + return cfg; + } + + template + static BatchKernelConfig make_config( + int64_t batch, + int64_t problem, + int64_t stride, + int64_t problem_batch, + bool problem_along_x, + std::vector policies, + int64_t prefer_wg_size = 0) { + BatchKernelConfig cfg = { + batch, + problem, + stride, + problem_batch, + problem_along_x, + policies, + prefer_wg_size}; + cfg.template build(); + + return cfg; + } + + sycl::range<2> global_size() const { + return {glb_range_y_, glb_range_x_}; + } + + sycl::range<2> group_size() const { + return {wg_range_y_, wg_range_x_}; + } + + struct ItemDesc { + /* chunk id along problem dim */ size_t chunk; + /* problem chunk size */ size_t chunk_size; + /* offsite in current chunk */ size_t chunk_off; + /* how many active chunks along problem dim */ size_t chunk_num; + /* global batch id */ size_t glb_batch; + /* global problem id */ size_t glb_problem; + }; + + ItemDesc get_item_desc(sycl::nd_item<2> item) const { + auto lix = item.get_local_id(1); + auto liy = item.get_local_id(0); + auto lrx = item.get_local_range(1); + auto lry = item.get_local_range(0); + auto wgrx = item.get_group_range(1); + auto wgry = item.get_group_range(0); + auto gix = item.get_global_id(1); + auto giy = item.get_global_id(0); + auto gx = item.get_group(1); + auto gy = item.get_group(0); + + // ItemDesc::glb_problem is 
meaningless, if policy is loop for all. + if (problem_along_x_) { + return {gx, lrx, lix, wgrx, giy, gix}; + } else { + return {gy, lry, liy, wgry, gix, giy}; + } + } + + // iterate over problems and batchs for `pAdaptive` policy + // # update workload status inplace in `desc`. + // # prioritize problem iteration. + bool next(sycl::nd_item<2> item, ItemDesc& desc) const { + auto next_problem = desc.glb_problem + problem_glb_range_; + auto next_batch = desc.glb_batch + batch_glb_range_; + auto cur_chunk = desc.chunk; + + // WA: break deduce chain, or offline compiler gets crash, due to, + // massive and deep divergence level + desc = get_item_desc(item); + + // iterate over problem + if (next_problem < problem_range_) { + desc.glb_problem = next_problem; + desc.chunk = cur_chunk + desc.chunk_num; + return true; + } + + // iterate over batch + if (next_batch < batch_range_) { + desc.glb_batch = next_batch; + return true; + } + + return false; + } + + static Policy suggest_policy( + int64_t batch, + int64_t problem, + int64_t stride, + bool problem_along_x, + bool bypass_adaptive_policy = true) { + auto target_wi_num = syclMaxWorkItemsPerTile(); + + if (!bypass_adaptive_policy && batch * problem * stride >= target_wi_num) { + return Policy::pAdaptive; + } + // Using device max work group size to deduce range configuration + // approximately. + BatchKernelConfig cfg_ = make_config( + batch, + problem, + stride, + batch * stride, + problem_along_x, + Policy::pLoop, + syclDeviceMaxWorkGroupSize()); + size_t wg_num = (cfg_.glb_range_x_ / cfg_.wg_range_x_) * + (cfg_.glb_range_y_ / cfg_.wg_range_y_); + size_t wg_size = cfg_.wg_range_x_ * cfg_.wg_range_y_; + if (wg_size * (wg_num + 1) > target_wi_num) { + return Policy::pLoop; + } + + return Policy::pSegment; + } + BatchKernelConfig( int64_t batch, int64_t problem, @@ -52,11 +192,14 @@ class BatchKernelConfig { template void build() { - size_t wg_size = syclMaxWorkGroupSize(); + size_t wg_size; size_t sg_size = syclMaxSubGroupSize(); + // Caller takes responsibility of if work group size is valid or compatible. if (prefer_wg_size_ != 0 && prefer_wg_size_ % sg_size == 0 && - prefer_wg_size_ < wg_size) { + prefer_wg_size_ <= syclDeviceMaxWorkGroupSize()) { wg_size = prefer_wg_size_; + } else { + wg_size = syclMaxWorkGroupSize(); } wg_range_x_ = sg_size; wg_range_y_ = wg_size / wg_range_x_; @@ -170,96 +313,6 @@ class BatchKernelConfig { }(), prefer_wg_size) {} - sycl::range<2> global_size() const { - return {glb_range_y_, glb_range_x_}; - } - - sycl::range<2> group_size() const { - return {wg_range_y_, wg_range_x_}; - } - - struct ItemDesc { - /* chunk id along problem dim */ size_t chunk; - /* problem chunk size */ size_t chunk_size; - /* offsite in current chunk */ size_t chunk_off; - /* how many active chunks along problem dim */ size_t chunk_num; - /* global batch id */ size_t glb_batch; - /* global problem id */ size_t glb_problem; - }; - - ItemDesc get_item_desc(sycl::nd_item<2> item) const { - auto lix = item.get_local_id(1); - auto liy = item.get_local_id(0); - auto lrx = item.get_local_range(1); - auto lry = item.get_local_range(0); - auto wgrx = item.get_group_range(1); - auto wgry = item.get_group_range(0); - auto gix = item.get_global_id(1); - auto giy = item.get_global_id(0); - auto gx = item.get_group(1); - auto gy = item.get_group(0); - - // ItemDesc::glb_problem is meaningless, if policy is loop for all. 
- if (problem_along_x_) { - return {gx, lrx, lix, wgrx, giy, gix}; - } else { - return {gy, lry, liy, wgry, gix, giy}; - } - } - - // iterate over problems and batchs for `pAdaptive` policy - // # update workload status inplace in `desc`. - // # prioritize problem iteration. - bool next(sycl::nd_item<2> item, ItemDesc& desc) const { - auto next_problem = desc.glb_problem + problem_glb_range_; - auto next_batch = desc.glb_batch + batch_glb_range_; - auto cur_chunk = desc.chunk; - - // WA: break deduce chain, or offline compiler gets crash, due to, - // massive and deep divergence level - desc = get_item_desc(item); - - // iterate over problem - if (next_problem < problem_range_) { - desc.glb_problem = next_problem; - desc.chunk = cur_chunk + desc.chunk_num; - return true; - } - - // iterate over batch - if (next_batch < batch_range_) { - desc.glb_batch = next_batch; - return true; - } - - return false; - } - - static Policy suggest_policy( - int64_t batch, - int64_t problem, - int64_t stride, - bool problem_along_x, - bool bypass_adaptive_policy = true) { - auto target_wi_num = syclMaxWorkItemsPerTile(); - - if (!bypass_adaptive_policy && batch * problem * stride >= target_wi_num) { - return Policy::pAdaptive; - } - - BatchKernelConfig cfg_ = { - batch, problem, stride, batch * stride, problem_along_x, Policy::pLoop}; - size_t wg_num = (cfg_.glb_range_x_ / cfg_.wg_range_x_) * - (cfg_.glb_range_y_ / cfg_.wg_range_y_); - size_t wg_size = cfg_.wg_range_x_ * cfg_.wg_range_y_; - - if (wg_size * (wg_num + 1) > target_wi_num) { - return Policy::pLoop; - } - - return Policy::pSegment; - } - public: /* logical shape desc */ int64_t batch_; /* logical shape desc */ int64_t problem_; diff --git a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp index 7fd51ddf3..eab6f4c48 100644 --- a/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp +++ b/src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp @@ -371,9 +371,8 @@ void launch_max_pool2d_kernel( auto& queue = at::xpu::getCurrentSYCLQueue(); int outputSize = numBatch * numPlane * outputSizeH * outputSizeW; int stride = numPlane * outputSizeH * outputSizeW; - BatchKernelConfig cfg = { - 1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive}; - cfg.template build(); + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, outputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); auto kfn = KernelClass( output, indices, @@ -436,8 +435,8 @@ void launch_max_pool2d_backward_kernel( using KernelClass = MaxPool2dBackwardDeterministicKernelFunctor; - BatchKernelConfig cfg = { - 1, gradInputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive}; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, gradInputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.template build(); auto kfn = KernelClass( gradInput, @@ -468,8 +467,8 @@ void launch_max_pool2d_backward_kernel( numBatch * numPlane * gradOutputSizeH * gradOutputSizeW; using KernelClass = MaxPool2dBackwardKernelFunctor; - BatchKernelConfig cfg = { - 1, gradOutputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive}; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + 1, gradOutputSize, 1, 1, true, BatchKernelConfig::Policy::pAdaptive); cfg.template build(); auto kfn = KernelClass( gradInput, diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index 5b9e4fa6b..0c7338d5d 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -57,9 
+57,8 @@ void embedding_bag( vec_idx_t* max_idx_vec = reinterpret_cast(max_index); vec_len = vec_len / vec_size; - BatchKernelConfig cfg = { - bag_num, vec_len, 1, bag_num, true, BatchKernelConfig::Policy::pAdaptive}; - cfg.template build(); + BatchKernelConfig cfg = BatchKernelConfig::make_config( + bag_num, vec_len, 1, bag_num, true, BatchKernelConfig::Policy::pAdaptive); index_t fixing_bag_size = ignore_offsets ? index_size / bag_num : 0; auto kfn = KernelClass( diff --git a/src/ATen/native/xpu/sycl/Indexing.h b/src/ATen/native/xpu/sycl/Indexing.h index d78ddfbf6..fcb429ef5 100644 --- a/src/ATen/native/xpu/sycl/Indexing.h +++ b/src/ATen/native/xpu/sycl/Indexing.h @@ -866,21 +866,17 @@ void launch_index_put_deterministic_kernel( return; } int64_t v_stride_before = numel * stride; - BatchKernelConfig cfg = { + // align with precision of CPU backend. + using accscalar_t = scalar_t; /* acc_type; */ + using KernelClass = IndexPutDeterministicKernelFunctor; + BatchKernelConfig cfg = BatchKernelConfig::make_config( /* num of indices */ numel, /* num of elements to put per indices */ outer_dim * stride, 1, numel, true, {BatchKernelConfig::Policy::pSegment, - BatchKernelConfig::Policy::pAggressiveSplit}}; - - // align with precision of CPU backend. - using accscalar_t = scalar_t; /* acc_type; */ - using KernelClass = IndexPutDeterministicKernelFunctor; - - cfg.template build(); - + BatchKernelConfig::Policy::pAggressiveSplit}); KernelClass kfn( sorted_indices, indices, diff --git a/src/ATen/native/xpu/sycl/WeightNormKernels.cpp b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp new file mode 100644 index 000000000..dd93f68c3 --- /dev/null +++ b/src/ATen/native/xpu/sycl/WeightNormKernels.cpp @@ -0,0 +1,1027 @@ +#include +#include +#include +#include +#include +#include +#include +#include "comm/Runtime.h" + +namespace at::native::xpu { + +template +struct ReduceAdd { + T operator()(const T a, const T b) const { + return a + b; + } +}; + +template < + class ScalarTypeInfo, + class AccTypeInfo, + typename scalar_t, + typename accscalar_t, + typename vec_t> +struct WeightNormReduceKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t ldr_pi = id.chunk * id.chunk_size + id.chunk_off; + int64_t str_pi = id.chunk; + int64_t ldr_lid = + si + ldr_pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; + int64_t ldr_off = at::xpu::detail::IndexToOffset::get( + ldr_lid, + iinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t str_lid = + si + str_pi * cfg_.stride_ + bi * id.chunk_num * cfg_.stride_; + int64_t str_off = at::xpu::detail::IndexToOffset::get( + str_lid, + oinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + accscalar_t value = 0; + if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + value = (accscalar_t)iinfo_.data[ldr_off]; + if (need_squre_) + value *= value; + } + + if (cfg_.problem_along_x_) { + value = group_x_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } else { + value = group_y_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } + + if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + if (id.chunk_off == 0) { + oinfo_.data[str_off] = is_final_ ? 
sqrtf(value) : value; + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(shared_memeory_size_, cgh); + } + WeightNormReduceKernelFunctor( + ScalarTypeInfo iinfo, + AccTypeInfo oinfo, + BatchKernelConfig cfg, + bool need_squre, + bool is_final, + int64_t shared_memeory_size) + : iinfo_(iinfo), + oinfo_(oinfo), + cfg_(cfg), + need_squre_(need_squre), + is_final_(is_final), + shared_memeory_size_(shared_memeory_size) {} + + private: + ScalarTypeInfo iinfo_; + AccTypeInfo oinfo_; + BatchKernelConfig cfg_; + bool need_squre_; + bool is_final_; + int64_t shared_memeory_size_; + sycl_local_acc_t shared_; +}; + +template +static inline void launch_weight_norm_reduce_kernel( + ScalarTypeInfo& iinfo, + AccTypeInfo& oinfo, + BatchKernelConfig& cfg, + bool need_squre, + bool is_final) { + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + + WeightNormReduceKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t, + vec_t> + kfn(iinfo, oinfo, cfg, need_squre, is_final, cfg.group_size().size()); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); +} + +template +static inline void weight_norm_reduce( + ScalarTypeInfo& vinfo, + AccTypeInfo& ninfo, + int dim_after_collapse, + bool need_square) { + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? true : false; + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + using KernelClass = WeightNormReduceKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t, + vec_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, problem, stride, batch * stride, problem_along_x); + + if (cfg.problem_ <= cfg.problem_wg_range_) { + launch_weight_norm_reduce_kernel(vinfo, ninfo, cfg, need_square, true); + return; + } + + Tensor carrier = at::empty( + {cfg.batch_, cfg.problem_glb_range_ / cfg.problem_wg_range_, cfg.stride_}, + map_options()); + auto cinfo = + at::xpu::detail::getTensorInfo( + carrier); + launch_weight_norm_reduce_kernel(vinfo, cinfo, cfg, need_square, false); + + weight_norm_reduce(cinfo, ninfo, 1, false); + return; +} + +template < + class ScalarTypeInfo, + class AccTypeInfo, + typename scalar_t, + typename accscalar_t> +struct SegmentWeightNormKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t pi = id.chunk * id.chunk_size + id.chunk_off; + int64_t w_lid = si + pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; + int64_t n_lid = id.glb_batch; + + int64_t v_off = at::xpu::detail::IndexToOffset::get( + w_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t w_off = at::xpu::detail::IndexToOffset::get( + w_lid, + winfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ginfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t n_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ninfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + if (id.glb_problem < 
cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + winfo_.data[w_off] = + (1.f / ninfo_.data[n_off]) * vinfo_.data[v_off] * ginfo_.data[g_off]; + } + } + SegmentWeightNormKernelFunctor( + ScalarTypeInfo vinfo, + ScalarTypeInfo ginfo, + ScalarTypeInfo winfo, + AccTypeInfo ninfo, + BatchKernelConfig cfg) + : vinfo_(vinfo), ginfo_(ginfo), winfo_(winfo), ninfo_(ninfo), cfg_(cfg) {} + + private: + ScalarTypeInfo vinfo_; + ScalarTypeInfo ginfo_; + ScalarTypeInfo winfo_; + AccTypeInfo ninfo_; + BatchKernelConfig cfg_; +}; + +template +static inline void segment_weight_norm( + ScalarTypeInfo& vinfo, + ScalarTypeInfo& ginfo, + ScalarTypeInfo& winfo, + AccTypeInfo& ninfo, + int dim_after_collapse) { + // segment reduce for statistics + weight_norm_reduce(vinfo, ninfo, dim_after_collapse, true); + + // normalization + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? true : false; + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + + using KernelClass = SegmentWeightNormKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, problem, stride, batch * stride, problem_along_x); + + KernelClass kfn(vinfo, ginfo, winfo, ninfo, cfg); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); +} + +template < + class ScalarTypeInfo, + class AccTypeInfo, + typename scalar_t, + typename accscalar_t, + typename vec_t> +struct WeightNormKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + int64_t n_lid = id.glb_batch; + + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ginfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t n_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ninfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t pi = id.chunk_off; + bi = si + bi * cfg_.problem_ * cfg_.stride_; + + accscalar_t value = 0; + if (id.glb_batch < cfg_.problem_batch_) { + for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { + int64_t v_lid = bi + pi_ * cfg_.stride_; + int64_t v_off = at::xpu::detail::IndexToOffset::get( + v_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + accscalar_t v = (accscalar_t)vinfo_.data[v_off]; + value += v * v; + } + } + + if (cfg_.problem_along_x_) { + value = group_x_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } else { + value = group_y_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } + + int n_slid = (int)id.glb_batch % batch_wg_range_; + if (id.glb_batch < cfg_.problem_batch_ && id.chunk_off == 0) { + value = sqrtf(value); + ninfo_.data[n_off] = value; + shared_[n_slid] = value; + } + // Here using slm instead. If using ugm, need fence w/ + // order:acq_rel & scope:workgroup & space:global_mem. 
+ item.barrier(sycl_local_fence); + + if (id.glb_batch < cfg_.problem_batch_) { + for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { + int64_t v_lid = bi + pi_ * cfg_.stride_; + int64_t v_off = at::xpu::detail::IndexToOffset::get( + v_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t w_off = at::xpu::detail::IndexToOffset::get( + v_lid, + winfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + winfo_.data[w_off] = + (1.f / shared_[n_slid]) * vinfo_.data[v_off] * ginfo_.data[g_off]; + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(wg_size_, cgh); + } + + WeightNormKernelFunctor( + ScalarTypeInfo vinfo, + ScalarTypeInfo ginfo, + ScalarTypeInfo winfo, + AccTypeInfo ninfo, + BatchKernelConfig cfg, + int wg_size, + int batch_wg_range) + : vinfo_(vinfo), + ginfo_(ginfo), + winfo_(winfo), + ninfo_(ninfo), + cfg_(cfg), + wg_size_(wg_size), + batch_wg_range_(batch_wg_range) {} + + private: + ScalarTypeInfo vinfo_; + ScalarTypeInfo ginfo_; + ScalarTypeInfo winfo_; + AccTypeInfo ninfo_; + BatchKernelConfig cfg_; + int wg_size_; + int batch_wg_range_; + sycl_local_acc_t shared_; +}; + +template +static inline void weight_norm( + ScalarTypeInfo& vinfo, + ScalarTypeInfo& ginfo, + ScalarTypeInfo& winfo, + AccTypeInfo& ninfo, + int dim_after_collapse) { + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? true : false; + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + + using KernelClass = WeightNormKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t, + vec_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, + problem, + stride, + batch * stride, + problem_along_x, + BatchKernelConfig::Policy::pLoop); + + int wg_size = cfg.group_size().size(); + int batch_wg_range = wg_size / cfg.problem_wg_range_; + KernelClass kfn(vinfo, ginfo, winfo, ninfo, cfg, wg_size, batch_wg_range); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); + + return; +} + +std::tuple weight_norm_kernel( + const Tensor& v, + const Tensor& g, + int64_t dim) { + TORCH_INTERNAL_ASSERT( + dim == 0 || dim == v.dim() - 1, + "fused kernels can only be applied for first or last dim"); + + at::ScalarType scalar_acc_t = (g.scalar_type() == at::ScalarType::Half || + g.scalar_type() == at::ScalarType::BFloat16) + ? 
at::ScalarType::Float + : g.scalar_type(); + auto norms = at::empty( + g.sizes(), g.options().dtype(scalar_acc_t), g.suggest_memory_format()); + auto w = at::empty(v.sizes(), v.options(), v.suggest_memory_format()); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + v.scalar_type(), + "aten::weight_norm", + [&] { + auto vinfo = at::xpu::detail::getTensorInfo(v); + int dim_after_collapse = vinfo.collapseDims(dim); + auto ginfo = at::xpu::detail::getTensorInfo(g); + ginfo.collapseDims(); + + auto winfo = at::xpu::detail::getTensorInfo(w); + winfo.collapseDims(dim); + using accscalar_t = acc_type; + auto ninfo = + at::xpu::detail::getTensorInfo(norms); + ninfo.collapseDims(); + dim_after_collapse = 1 - dim_after_collapse; // remain dim + + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = + vinfo.strides[dim_after_collapse] == 1 ? true : false; + if (BatchKernelConfig::Policy::pSegment == + BatchKernelConfig::suggest_policy( + batch, problem, stride, problem_along_x)) { + segment_weight_norm(vinfo, ginfo, winfo, ninfo, dim_after_collapse); + } else { + weight_norm(vinfo, ginfo, winfo, ninfo, dim_after_collapse); + } + }); + + return {w, norms}; +} + +template < + bool is_first, + class ScalarType1Info, + class ScalarType2Info, + class AccTypeInfo, + typename scalar1_t, + typename scalar2_t, + typename accscalar_t, + typename vec_t> +struct WeightNormBackwardReduceKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t i_pi = id.chunk * id.chunk_size + id.chunk_off; + int64_t o_pi = id.chunk; + + int64_t i_lid = + si + i_pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; + int64_t i1_off = at::xpu::detail::IndexToOffset::get( + i_lid, + i1info_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t i2_off; + if (is_first) { + i2_off = at::xpu::detail::IndexToOffset::get( + i_lid, + i2info_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + } + + int64_t o_lid = si + o_pi * cfg_.stride_ + bi * id.chunk_num * cfg_.stride_; + int64_t o_off = at::xpu::detail::IndexToOffset::get( + o_lid, + oinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + accscalar_t value = 0; + if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + if (is_first) { + auto value1 = (accscalar_t)i1info_.data[i1_off]; + auto value2 = (accscalar_t)i2info_.data[i2_off]; + value = value1 * value2; + } else { + value = (accscalar_t)i1info_.data[i1_off]; + } + } + + if (cfg_.problem_along_x_) { + value = group_x_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } else { + value = group_y_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } + + if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + if (id.chunk_off == 0) { + oinfo_.data[o_off] = value; + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(local_size_, cgh); + } + WeightNormBackwardReduceKernelFunctor( + ScalarType1Info i1info, + ScalarType2Info i2info, + AccTypeInfo oinfo, + BatchKernelConfig cfg, + int64_t local_size) + : i1info_(i1info), + i2info_(i2info), + oinfo_(oinfo), + cfg_(cfg), + local_size_(local_size) {} + + private: + ScalarType1Info i1info_; 
+ ScalarType2Info i2info_; + AccTypeInfo oinfo_; + BatchKernelConfig cfg_; + int64_t local_size_; + sycl_local_acc_t shared_; +}; + +template < + bool is_first, + class ScalarType1Info, + class ScalarType2Info, + class AccTypeInfo> +static inline void launch_weight_norm_backward_reduce_kernel( + ScalarType1Info& i1info, + ScalarType2Info& i2info, + AccTypeInfo& oinfo, + BatchKernelConfig& cfg) { + using scalar1_t = typename ScalarType1Info::scalar_t; + using scalar2_t = typename ScalarType2Info::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + WeightNormBackwardReduceKernelFunctor< + is_first, + ScalarType1Info, + ScalarType2Info, + AccTypeInfo, + scalar1_t, + scalar2_t, + accscalar_t, + vec_t> + kfn(i1info, i2info, oinfo, cfg, cfg.group_size().size()); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); +} + +template +static inline void weight_norm_backward_reduce( + ScalarType1Info& vinfo, + ScalarType2Info& gwinfo, + AccTypeInfo& rinfo, + int dim_after_collapse, + bool is_first) { + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? true : false; + + using scalar1_t = typename ScalarType1Info::scalar_t; + using scalar2_t = typename ScalarType2Info::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + using KernelClass = WeightNormBackwardReduceKernelFunctor< + true, + ScalarType1Info, + ScalarType2Info, + AccTypeInfo, + scalar1_t, + scalar2_t, + accscalar_t, + vec_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, problem, stride, batch * stride, problem_along_x); + if (cfg.problem_ <= cfg.problem_wg_range_) { + if (is_first) { + launch_weight_norm_backward_reduce_kernel( + vinfo, gwinfo, rinfo, cfg); + } else { + launch_weight_norm_backward_reduce_kernel( + vinfo, gwinfo, rinfo, cfg); + } + return; + } + + Tensor carrier = at::empty( + {cfg.batch_, cfg.problem_glb_range_ / cfg.problem_wg_range_, cfg.stride_}, + map_options()); + auto cinfo = + at::xpu::detail::getTensorInfo( + carrier); + if (is_first) { + launch_weight_norm_backward_reduce_kernel(vinfo, gwinfo, cinfo, cfg); + } else { + launch_weight_norm_backward_reduce_kernel(vinfo, gwinfo, cinfo, cfg); + } + + weight_norm_backward_reduce(cinfo, gwinfo, rinfo, 1, false); + return; +} + +template < + class ScalarTypeInfo, + class AccTypeInfo, + typename scalar_t, + typename accscalar_t> +struct SegmentWeightNormBackwardKernelFunctor { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t pi = id.chunk * id.chunk_size + id.chunk_off; + + int64_t gv_lid = si + pi * cfg_.stride_ + bi * cfg_.problem_ * cfg_.stride_; + int64_t gg_lid = id.glb_batch; + + int64_t v_off = at::xpu::detail::IndexToOffset::get( + gv_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t gw_off = at::xpu::detail::IndexToOffset::get( + gv_lid, + gwinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t gv_off = at::xpu::detail::IndexToOffset::get( + gv_lid, + gvinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t g_off = at::xpu::detail::IndexToOffset::get( + gg_lid, + ginfo_, + at::xpu::detail::IndexToOffset:: + 
NON_STRICT_CONTIGUOUS); + + int64_t n_off = at::xpu::detail::IndexToOffset::get( + gg_lid, + ninfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t r_off = at::xpu::detail::IndexToOffset::get( + gg_lid, + rinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + int64_t gg_off = at::xpu::detail::IndexToOffset::get( + gg_lid, + gginfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + if (id.glb_problem < cfg_.problem_ && id.glb_batch < cfg_.problem_batch_) { + accscalar_t g = ginfo_.data[g_off]; + accscalar_t gw = gwinfo_.data[gw_off]; + accscalar_t v = vinfo_.data[v_off]; + accscalar_t n = 1.f / ninfo_.data[n_off]; + accscalar_t r = rinfo_.data[r_off]; + accscalar_t gg = r * n; + accscalar_t n3 = n * n * n; + accscalar_t gv = g * (n * gw - n3 * v * r); + + gvinfo_.data[gv_off] = static_cast(gv); + if (id.chunk == 0 && id.chunk_off == 0) + gginfo_.data[gg_off] = static_cast(gg); + } + } + SegmentWeightNormBackwardKernelFunctor( + ScalarTypeInfo vinfo, + ScalarTypeInfo ginfo, + ScalarTypeInfo gwinfo, + AccTypeInfo ninfo, + ScalarTypeInfo gvinfo, + ScalarTypeInfo gginfo, + AccTypeInfo rinfo, + BatchKernelConfig cfg) + : vinfo_(vinfo), + ginfo_(ginfo), + gwinfo_(gwinfo), + ninfo_(ninfo), + gvinfo_(gvinfo), + gginfo_(gginfo), + rinfo_(rinfo), + cfg_(cfg) {} + + private: + ScalarTypeInfo vinfo_; + ScalarTypeInfo ginfo_; + ScalarTypeInfo gwinfo_; + AccTypeInfo ninfo_; + ScalarTypeInfo gvinfo_; + ScalarTypeInfo gginfo_; + AccTypeInfo rinfo_; + BatchKernelConfig cfg_; +}; + +template +static inline void segment_weight_norm_backward( + ScalarTypeInfo& vinfo, + ScalarTypeInfo& ginfo, + ScalarTypeInfo& gwinfo, + AccTypeInfo& ninfo, + ScalarTypeInfo& gvinfo, + ScalarTypeInfo& gginfo, + AccTypeInfo& rinfo, + int dim_after_collapse) { + // segment reduce + weight_norm_backward_reduce(vinfo, gwinfo, rinfo, dim_after_collapse, true); + + // compute gradient + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? 
true : false; + + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using KernelClass = SegmentWeightNormBackwardKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, problem, stride, batch * stride, problem_along_x); + + KernelClass kfn(vinfo, ginfo, gwinfo, ninfo, gvinfo, gginfo, rinfo, cfg); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); + + return; +} + +template < + class ScalarTypeInfo, + class AccTypeInfo, + typename scalar_t, + typename accscalar_t, + typename vec_t> +struct WeightNormBackwardKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + void operator()(sycl::nd_item<2> item) const { + auto id = cfg_.get_item_desc(item); + int64_t n_lid = id.glb_batch; + int64_t g_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ginfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t gg_off = at::xpu::detail::IndexToOffset::get( + n_lid, + gginfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t n_off = at::xpu::detail::IndexToOffset::get( + n_lid, + ninfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + int64_t si = id.glb_batch % cfg_.stride_; + int64_t bi = id.glb_batch / cfg_.stride_; + int64_t pi = id.chunk_off; + bi = si + bi * cfg_.problem_ * cfg_.stride_; + + accscalar_t value = 0; + if (id.glb_batch < cfg_.problem_batch_) { + for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { + int64_t v_lid, v_off, gw_off; + v_lid = bi + pi_ * cfg_.stride_; + + v_off = at::xpu::detail::IndexToOffset::get( + v_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + gw_off = at::xpu::detail::IndexToOffset::get( + v_lid, + gwinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + accscalar_t v = (accscalar_t)vinfo_.data[v_off]; + accscalar_t gw = (accscalar_t)gwinfo_.data[gw_off]; + value += v * gw; + } + } + + if (cfg_.problem_along_x_) { + value = group_x_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } else { + value = group_y_reduce( + item, shared_, vec_t(value), ReduceAdd())[0]; + } + + int n_slid = (int)id.glb_batch % batch_wg_range_; + if (id.glb_batch < cfg_.problem_batch_ && id.chunk_off == 0) { + shared_[n_slid] = value; + } + item.barrier(sycl_local_fence); + + if (id.glb_batch < cfg_.problem_batch_) { + for (int pi_ = pi; pi_ < cfg_.problem_; pi_ += cfg_.problem_wg_range_) { + int64_t v_lid, v_off, gw_off, gv_off; + v_lid = bi + pi_ * cfg_.stride_; + + v_off = at::xpu::detail::IndexToOffset::get( + v_lid, + vinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + gw_off = at::xpu::detail::IndexToOffset::get( + v_lid, + gwinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + gv_off = at::xpu::detail::IndexToOffset::get( + v_lid, + gvinfo_, + at::xpu::detail::IndexToOffset:: + NON_STRICT_CONTIGUOUS); + + accscalar_t g = ginfo_.data[g_off]; + accscalar_t gw = gwinfo_.data[gw_off]; + accscalar_t v = vinfo_.data[v_off]; + accscalar_t n = 1.f / ninfo_.data[n_off]; + accscalar_t r = shared_[n_slid]; + accscalar_t gg = r * n; + accscalar_t n3 = n * n * n; + accscalar_t gv = g * (n * gw - n3 * v * r); + + gvinfo_.data[gv_off] = static_cast(gv); + if (id.chunk_off == 0) + gginfo_.data[gg_off] = static_cast(gg); + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(wg_size_, cgh); + } + + 
WeightNormBackwardKernelFunctor( + ScalarTypeInfo vinfo, + ScalarTypeInfo ginfo, + ScalarTypeInfo gwinfo, + AccTypeInfo ninfo, + ScalarTypeInfo gvinfo, + ScalarTypeInfo gginfo, + BatchKernelConfig cfg, + int wg_size, + int batch_wg_range) + : vinfo_(vinfo), + ginfo_(ginfo), + gwinfo_(gwinfo), + ninfo_(ninfo), + gvinfo_(gvinfo), + gginfo_(gginfo), + cfg_(cfg), + wg_size_(wg_size), + batch_wg_range_(batch_wg_range) {} + + private: + ScalarTypeInfo vinfo_; + ScalarTypeInfo ginfo_; + ScalarTypeInfo gwinfo_; + AccTypeInfo ninfo_; + ScalarTypeInfo gvinfo_; + ScalarTypeInfo gginfo_; + BatchKernelConfig cfg_; + int wg_size_; + int batch_wg_range_; + sycl_local_acc_t shared_; +}; + +template +static inline void weight_norm_backward( + ScalarTypeInfo& vinfo, + ScalarTypeInfo& ginfo, + ScalarTypeInfo& gwinfo, + AccTypeInfo& ninfo, + ScalarTypeInfo& gvinfo, + ScalarTypeInfo& gginfo, + int dim_after_collapse) { + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = vinfo.strides[dim_after_collapse] == 1 ? true : false; + + using scalar_t = typename ScalarTypeInfo::scalar_t; + using accscalar_t = typename AccTypeInfo::scalar_t; + using vec_t = at::detail::Array; + using KernelClass = WeightNormBackwardKernelFunctor< + ScalarTypeInfo, + AccTypeInfo, + scalar_t, + accscalar_t, + vec_t>; + BatchKernelConfig cfg = BatchKernelConfig::make_config( + batch, + problem, + stride, + batch * stride, + problem_along_x, + BatchKernelConfig::Policy::pLoop); + int wg_size = cfg.group_size().size(); + int batch_wg_range = wg_size / cfg.problem_wg_range_; + KernelClass kfn( + vinfo, + ginfo, + gwinfo, + ninfo, + gvinfo, + gginfo, + cfg, + wg_size, + batch_wg_range); + sycl_kernel_submit( + cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); + return; +} + +std::tuple weight_norm_backward_kernel( + const Tensor& grad_w, + const Tensor& saved_v, + const Tensor& saved_g, + const Tensor& saved_norms, + int64_t dim) { + auto grad_v = at::empty_like(saved_v, c10::get_contiguous_memory_format()); + auto grad_g = at::empty_like(saved_g, c10::get_contiguous_memory_format()); + + at::ScalarType scalar_acc_t = + (saved_g.scalar_type() == at::ScalarType::Half || + saved_g.scalar_type() == at::ScalarType::BFloat16) + ? at::ScalarType::Float + : saved_g.scalar_type(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + saved_v.scalar_type(), + "aten::weight_norm_backward", + [&] { + auto vinfo = at::xpu::detail::getTensorInfo(saved_v); + int dim_after_collapse = vinfo.collapseDims(dim); + + auto ginfo = at::xpu::detail::getTensorInfo(saved_g); + ginfo.collapseDims(); + + auto gwinfo = at::xpu::detail::getTensorInfo(grad_w); + gwinfo.collapseDims(dim); + using accscalar_t = acc_type; + auto ninfo = + at::xpu::detail::getTensorInfo(saved_norms); + ninfo.collapseDims(); + + auto gvinfo = at::xpu::detail::getTensorInfo(grad_v); + gvinfo.collapseDims(dim); + + auto gginfo = at::xpu::detail::getTensorInfo(grad_g); + gginfo.collapseDims(); + + dim_after_collapse = 1 - dim_after_collapse; // remain dim + + int64_t batch = vinfo.outerSize(dim_after_collapse); + int64_t problem = vinfo.sizes[dim_after_collapse]; + int64_t stride = vinfo.innerSize(dim_after_collapse); + bool problem_along_x = + vinfo.strides[dim_after_collapse] == 1 ? 
true : false;
+        if (BatchKernelConfig::Policy::pSegment ==
+            BatchKernelConfig::suggest_policy(
+                batch, problem, stride, problem_along_x)) {
+          auto reduce = at::empty(
+              saved_g.sizes(),
+              saved_g.options().dtype(scalar_acc_t),
+              c10::get_contiguous_memory_format());
+          auto rinfo =
+              at::xpu::detail::getTensorInfo(reduce);
+          rinfo.collapseDims();
+
+          segment_weight_norm_backward(
+              vinfo,
+              ginfo,
+              gwinfo,
+              ninfo,
+              gvinfo,
+              gginfo,
+              rinfo,
+              dim_after_collapse);
+        } else {
+          weight_norm_backward(
+              vinfo, ginfo, gwinfo, ninfo, gvinfo, gginfo, dim_after_collapse);
+        }
+      });
+
+  return {grad_v, grad_g};
+}
+
+} // namespace at::native::xpu
\ No newline at end of file
diff --git a/src/ATen/native/xpu/sycl/WeightNormKernels.h b/src/ATen/native/xpu/sycl/WeightNormKernels.h
new file mode 100644
index 000000000..5b0d7afd2
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/WeightNormKernels.h
@@ -0,0 +1,16 @@
+#pragma once
+#include <ATen/ATen.h>
+
+namespace at::native::xpu {
+std::tuple<Tensor, Tensor> weight_norm_kernel(
+    const Tensor& v,
+    const Tensor& g,
+    int64_t dim);
+
+std::tuple<Tensor, Tensor> weight_norm_backward_kernel(
+    const Tensor& grad_w,
+    const Tensor& saved_v,
+    const Tensor& saved_g,
+    const Tensor& saved_norms,
+    int64_t dim);
+} // namespace at::native::xpu
\ No newline at end of file
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 5c354e0b0..2ca8e8fc1 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -694,4 +694,6 @@ supported:
 - renorm.out
 - renorm_
 - nan_to_num.out
+ - _weight_norm_interface
+ - _weight_norm_interface_backward
 - range.out
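
Note on the backward math implemented above: with r = sum(v * grad_w) reduced over the norm dimension and n = 1 / ||v||, both backward functors compute grad_g = r * n and grad_v = g * (n * grad_w - n^3 * v * r), which is the standard weight-norm gradient. The snippet below is a minimal, hypothetical sanity check for the ops this patch routes to XPU. It is a sketch, not part of the patch: it assumes an XPU-enabled PyTorch build where the "xpu" device is available, and the shapes, tolerances, and the helper name `check_weight_norm` are illustrative assumptions.

```python
import torch

def check_weight_norm(dim=0, dtype=torch.float32):
    # Assumes an XPU build; on other builds replace "xpu" with an available device.
    v = torch.randn(32, 64, dtype=dtype, device="xpu", requires_grad=True)
    g = torch.randn(32, 1, dtype=dtype, device="xpu", requires_grad=True)

    # Forward: _weight_norm_interface returns (w, norms); norms are saved for backward.
    w, norms = torch._weight_norm_interface(v, g, dim)

    # Reference on CPU: w = g * v / ||v||, norm taken over every dim except `dim`.
    v_cpu, g_cpu = v.detach().cpu(), g.detach().cpu()
    norm_ref = v_cpu.norm(2, dim=1, keepdim=True)
    w_ref = g_cpu * v_cpu / norm_ref
    torch.testing.assert_close(w.cpu(), w_ref, rtol=1e-4, atol=1e-4)

    # Backward: exercises _weight_norm_interface_backward via autograd.
    w.sum().backward()
    assert v.grad is not None and g.grad is not None

check_weight_norm()
```

The same check can be repeated with `dtype=torch.float16` or `torch.bfloat16` (with looser tolerances), since the dispatch above covers Half/BFloat16 and accumulates norms in float for reduced-precision inputs.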