Commit 9fc5d81

min-jean-cho, min.jean.cho, and xytintel authored
Add aten::_foreach_clamp_max (#967)
- `_foreach_clamp_max.List`
- `_foreach_clamp_max_.List`
- `_foreach_clamp_max.Scalar`
- `_foreach_clamp_max_.Scalar`
- `_foreach_clamp_max.ScalarList`
- `_foreach_clamp_max_.ScalarList`

---------

Co-authored-by: min.jean.cho <[email protected]>
Co-authored-by: Yutao Xu <[email protected]>
1 parent f592ff9 commit 9fc5d81

10 files changed: +81 -0 lines changed
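Usage note: once built, the new overloads are reachable through the generated ATen C++ API (at::_foreach_clamp_max / at::_foreach_clamp_max_; at the Python level, torch._foreach_clamp_max). A minimal sketch, assuming a build with these XPU kernels registered; on tensors of any other backend the same calls take the CompositeExplicitAutograd slow path registered in the YAML below:

#include <ATen/ATen.h>
#include <vector>

int main() {
  // foreach ops act per tensor, so shapes may differ across the list.
  std::vector<at::Tensor> ts = {at::randn({4}), at::randn({2, 3})};

  // Scalar overload: clamp every element of every tensor to at most 0.5.
  std::vector<at::Tensor> out = at::_foreach_clamp_max(ts, 0.5);

  // List overload: elementwise minimum against a matching list of tensors.
  std::vector<at::Tensor> bounds = {at::zeros({4}), at::zeros({2, 3})};
  out = at::_foreach_clamp_max(ts, bounds);

  // In-place ScalarList overload: one scalar bound per tensor.
  at::_foreach_clamp_max_(ts, {0.1, 0.9});
  return 0;
}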

src/ATen/native/xpu/ForeachOpList.cpp

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_addcdiv_native.h>
 #include <ATen/ops/_foreach_addcmul_native.h>
+#include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_lerp_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
@@ -65,6 +66,7 @@ namespace native {
 FOREACH_BINARY_OP_LIST_ALPHA(add);
 FOREACH_BINARY_OP_LIST(mul, false);
 FOREACH_BINARY_OP_LIST(div, true);
+FOREACH_BINARY_OP_LIST(clamp_max, true);
 FOREACH_BINARY_OP_LIST(clamp_min, true);
 
 #define FOREACH_POINTWISE_OP_TENSOR(NAME) \
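The FOREACH_BINARY_OP_LIST macro itself is defined earlier in ForeachOpList.cpp and is not part of this diff; the newly added include pulls in the generated declarations of the *_slow fallbacks. Assuming the macro follows the shape of PyTorch's CUDA foreach wrappers, each registration emits an out-of-place and an in-place entry point that validate the tensor lists and pick between the fused kernel and the per-tensor slow path; the second argument (true here, as for div and clamp_min) is presumably the flag forwarded to can_use_fast_route marking ops whose result type can promote integer inputs to floating point, which forces such inputs onto the slow path. A rough sketch of the out-of-place half, with hypothetical names except the *_kernel_xpu/_slow entry points:

// Hypothetical expansion of FOREACH_BINARY_OP_LIST(clamp_max, true),
// modeled on PyTorch's CUDA foreach wrappers; the real macro may differ.
#include <ATen/ATen.h>
#include <ATen/native/ForeachUtils.h>
#include <vector>

namespace at {
namespace native {

std::vector<Tensor> foreach_tensor_clamp_max_list_kernel_xpu(
    TensorList tensors1, TensorList tensors2) {
  check_foreach_api_restrictions(tensors1, tensors2);
  if (!can_use_fast_route(
          tensors1, tensors2, /*does_op_promote_integer_inputs_to_float=*/true)) {
    // Per-tensor fallback: one ordinary clamp_max dispatch per list element.
    return foreach_tensor_clamp_max_list_kernel_slow(tensors1, tensors2);
  }
  // Fused SYCL kernel declared in sycl/ForeachBinaryOpListKernels.h
  // (hypothetical call name).
  return xpu::foreach_clamp_max_list_kernel(tensors1, tensors2);
}

} // namespace native
} // namespace at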

src/ATen/native/xpu/ForeachOpScalar.cpp

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_addcdiv_native.h>
 #include <ATen/ops/_foreach_addcmul_native.h>
+#include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_lerp_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
@@ -38,6 +39,7 @@ namespace native {
 FOREACH_BINARY_OP_SCALAR(add, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALAR(mul, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALAR(div, /*div_op*/ true);
+FOREACH_BINARY_OP_SCALAR(clamp_max, /*div_op*/ true);
 FOREACH_BINARY_OP_SCALAR(clamp_min, /*div_op*/ true);
 
 #define FOREACH_POINTWISE_OP_SCALAR(NAME) \

src/ATen/native/xpu/ForeachOpScalarList.cpp

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include <ATen/ops/_foreach_add_native.h>
 #include <ATen/ops/_foreach_addcdiv_native.h>
 #include <ATen/ops/_foreach_addcmul_native.h>
+#include <ATen/ops/_foreach_clamp_max_native.h>
 #include <ATen/ops/_foreach_div_native.h>
 #include <ATen/ops/_foreach_mul_native.h>
 #include <ATen/ops/_foreach_clamp_min_native.h>
@@ -41,6 +42,7 @@ namespace native {
 FOREACH_BINARY_OP_SCALARLIST(add, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALARLIST(mul, /*div_op*/ false);
 FOREACH_BINARY_OP_SCALARLIST(div, /*div_op*/ true);
+FOREACH_BINARY_OP_SCALARLIST(clamp_max, /*div_op*/ true);
 FOREACH_BINARY_OP_SCALARLIST(clamp_min, /*div_op*/ true);
 
 #define FOREACH_POINTWISE_OP_SCALARLIST(NAME) \

src/ATen/native/xpu/sycl/ForeachBinaryOpListKernels.cpp

Lines changed: 8 additions & 0 deletions
@@ -182,6 +182,14 @@ FOREACH_BINARY_LIST_KERNEL(div) {
   return all_types_complex_bool_half_bfloat16<std::divides>(tensor1, tensor2);
 }
 
+FOREACH_BINARY_LIST_INPLACE_KERNEL(clamp_max) {
+  return all_types_half_bfloat16_<foreach_internal::minimum>(tensor1, tensor2);
+}
+
+FOREACH_BINARY_LIST_KERNEL(clamp_max) {
+  return all_types_half_bfloat16<foreach_internal::minimum>(tensor1, tensor2);
+}
+
 FOREACH_BINARY_LIST_INPLACE_KERNEL(clamp_min) {
   return all_types_half_bfloat16_<foreach_internal::maximum>(tensor1, tensor2);
 }
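foreach_internal::minimum, the per-element functor these kernels instantiate, is defined elsewhere in the tree and is not part of the diff. Assuming it matches the minimum/maximum functors used by PyTorch's CUDA foreach kernels, clamp_max reduces to an elementwise minimum that propagates NaN, consistent with at::clamp_max. A sketch under that assumption:

// Hypothetical definition of foreach_internal::minimum; the actual
// functor in torch-xpu-ops may differ in detail.
#include <ATen/NumericUtils.h> // at::_isnan

namespace foreach_internal {

template <typename T>
struct minimum {
  T operator()(const T& a, const T& b) const {
    // clamp_max(a, b) == min(a, b), except that a NaN in either
    // operand propagates to the output.
    return (at::_isnan(a) || a < b) ? a : b;
  }
};

} // namespace foreach_internal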

src/ATen/native/xpu/sycl/ForeachBinaryOpListKernels.h

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ TORCH_XPU_API FOREACH_BINARY_LIST_INPLACE_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_LIST_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_LIST_INPLACE_KERNEL(div);
 TORCH_XPU_API FOREACH_BINARY_LIST_KERNEL(div);
+TORCH_XPU_API FOREACH_BINARY_LIST_INPLACE_KERNEL(clamp_max);
+TORCH_XPU_API FOREACH_BINARY_LIST_KERNEL(clamp_max);
 TORCH_XPU_API FOREACH_BINARY_LIST_INPLACE_KERNEL(clamp_min);
 TORCH_XPU_API FOREACH_BINARY_LIST_KERNEL(clamp_min);

src/ATen/native/xpu/sycl/ForeachBinaryOpScalarKernels.cpp

Lines changed: 8 additions & 0 deletions
@@ -150,6 +150,14 @@ FOREACH_BINARY_SCALAR_KERNEL(div) {
   return all_types_complex_bool_half_bfloat16<std::divides>(tensors, scalar);
 }
 
+FOREACH_BINARY_SCALAR_INPLACE_KERNEL(clamp_max) {
+  return all_types_half_bfloat16_<foreach_internal::minimum>(tensors, scalar);
+}
+
+FOREACH_BINARY_SCALAR_KERNEL(clamp_max) {
+  return all_types_half_bfloat16<foreach_internal::minimum>(tensors, scalar);
+}
+
 FOREACH_BINARY_SCALAR_INPLACE_KERNEL(clamp_min) {
   return all_types_half_bfloat16_<foreach_internal::maximum>(tensors, scalar);
 }

src/ATen/native/xpu/sycl/ForeachBinaryOpScalarKernels.h

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ TORCH_XPU_API FOREACH_BINARY_SCALAR_INPLACE_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_SCALAR_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_SCALAR_INPLACE_KERNEL(div);
 TORCH_XPU_API FOREACH_BINARY_SCALAR_KERNEL(div);
+TORCH_XPU_API FOREACH_BINARY_SCALAR_INPLACE_KERNEL(clamp_max);
+TORCH_XPU_API FOREACH_BINARY_SCALAR_KERNEL(clamp_max);
 TORCH_XPU_API FOREACH_BINARY_SCALAR_INPLACE_KERNEL(clamp_min);
 TORCH_XPU_API FOREACH_BINARY_SCALAR_KERNEL(clamp_min);

src/ATen/native/xpu/sycl/ForeachBinaryOpScalarListKernels.cpp

Lines changed: 8 additions & 0 deletions
@@ -155,6 +155,14 @@ FOREACH_BINARY_SCALARLIST_KERNEL(div) {
   return all_types_complex_bool_half_bfloat16<std::divides>(tensors, scalars);
 }
 
+FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(clamp_max) {
+  return all_types_half_bfloat16_<foreach_internal::minimum>(tensors, scalars);
+}
+
+FOREACH_BINARY_SCALARLIST_KERNEL(clamp_max) {
+  return all_types_half_bfloat16<foreach_internal::minimum>(tensors, scalars);
+}
+
 FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(clamp_min) {
   return all_types_half_bfloat16_<foreach_internal::maximum>(tensors, scalars);
 }

src/ATen/native/xpu/sycl/ForeachBinaryOpScalarListKernels.h

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ TORCH_XPU_API FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_SCALARLIST_KERNEL(mul);
 TORCH_XPU_API FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(div);
 TORCH_XPU_API FOREACH_BINARY_SCALARLIST_KERNEL(div);
+TORCH_XPU_API FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(clamp_max);
+TORCH_XPU_API FOREACH_BINARY_SCALARLIST_KERNEL(clamp_max);
 TORCH_XPU_API FOREACH_BINARY_SCALARLIST_INPLACE_KERNEL(clamp_min);
 TORCH_XPU_API FOREACH_BINARY_SCALARLIST_KERNEL(clamp_min);

yaml/native/native_functions.yaml

Lines changed: 45 additions & 0 deletions
@@ -2003,6 +2003,51 @@
     XPU: foreach_tensor_div_scalar_kernel_xpu_
   autogen: _foreach_div.Scalar_out
 
+- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow
+    XPU: foreach_tensor_clamp_max_scalar_kernel_xpu
+
+- func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_
+    XPU: foreach_tensor_clamp_max_scalar_kernel_xpu_
+  autogen: _foreach_clamp_max.Scalar_out
+
+- func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow
+    XPU: foreach_tensor_clamp_max_list_kernel_xpu
+
+- func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_
+    XPU: foreach_tensor_clamp_max_list_kernel_xpu_
+  autogen: _foreach_clamp_max.List_out
+
+- func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow
+    XPU: foreach_tensor_clamp_max_scalarlist_kernel_xpu
+
+- func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_
+    XPU: foreach_tensor_clamp_max_scalarlist_kernel_xpu_
+  autogen: _foreach_clamp_max.ScalarList_out
+
 - func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
   device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices
   variants: function
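The CompositeExplicitAutograd entries above route to the *_kernel_slow functions, which any backend without a fused kernel hits, and which the foreach ops also fall back to when the fast-route checks fail (e.g. tensors on different devices, per the YAML comments). In PyTorch's reference foreach implementation the slow path is just a loop dispatching the ordinary per-tensor op; a sketch of that shape for the List overload, assuming this extension mirrors upstream:

// Hypothetical shape of foreach_tensor_clamp_max_list_kernel_slow,
// assuming it mirrors PyTorch's reference foreach implementation.
#include <ATen/ATen.h>
#include <vector>

std::vector<at::Tensor> foreach_tensor_clamp_max_list_kernel_slow(
    at::TensorList tensors1, at::TensorList tensors2) {
  std::vector<at::Tensor> result;
  result.reserve(tensors1.size());
  for (size_t i = 0; i < tensors1.size(); ++i) {
    // One regular clamp_max dispatch per tensor pair.
    result.push_back(at::clamp_max(tensors1[i], tensors2[i]));
  }
  return result;
}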
