Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
hjhee committed Jul 10, 2024
1 parent 0013b22 commit 023fd27
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 25 deletions.
12 changes: 0 additions & 12 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,18 +502,6 @@ Tensor& XPUNativeFunctions::floor_divide_(Tensor& self, const Tensor& other) {
return XPUNativeFunctions::floor_divide_out(self, other, self);
}

Tensor XPUNativeFunctions::floor_divide(
const Tensor& self,
const Scalar& other) {
auto wrapper = native::wrapped_scalar_tensor(other);
return XPUNativeFunctions::floor_divide(self, wrapper);
}

Tensor& XPUNativeFunctions::floor_divide_(Tensor& self, const Scalar& other) {
auto wrapper = native::wrapped_scalar_tensor(other);
return XPUNativeFunctions::floor_divide_out(self, wrapper, self);
}

TensorIterator meta_fmin_fmax(
const char* const name,
const Tensor& self,
Expand Down
11 changes: 3 additions & 8 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,9 @@ struct FloorFunctor<c10::complex<T>> {
};

void floor_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
iter.common_dtype(),
"floor_xpu",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
gpu_kernel(iter, FloorFunctor<opmath_t>());
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "floor_xpu", [&]() {
gpu_kernel(iter, FloorFunctor<scalar_t>());
});
}

Expand Down
4 changes: 4 additions & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@
"renorm",
"lerp",
"conj_physical",
"fmax",
"fmin",
"floor",
"floor_divide"
]


Expand Down
5 changes: 0 additions & 5 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -506,23 +506,18 @@ supported:
- randperm.generator_out
- _amp_foreach_non_finite_check_and_unscale_
- _amp_update_scale_
<<<<<<< HEAD
- floor
- floor_
- floor.out
- floor_divide
- floor_divide_.Tensor
- floor_divide.out
- floor_divide.Scalar
- floor_divide_.Scalar
- fmax
- fmax.out
- fmin
- fmin.out
=======
- conj_physical.out
- conj_physical_
- ceil
- ceil_
- ceil.out
>>>>>>> origin/main

0 comments on commit 023fd27

Please sign in to comment.