diff --git a/src/ATen/native/xpu/UnaryOps.cpp b/src/ATen/native/xpu/UnaryOps.cpp
index 4d07a466b..ffc528fab 100644
--- a/src/ATen/native/xpu/UnaryOps.cpp
+++ b/src/ATen/native/xpu/UnaryOps.cpp
@@ -5,6 +5,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -515,6 +516,18 @@ Tensor& XPUNativeFunctions::erfc_out(const Tensor& self, Tensor& out) {
   return out;
 }
 
+Tensor& XPUNativeFunctions::conj_physical_out(const Tensor& self, Tensor& out) {
+  auto iter = TensorIterator::unary_op(out, self);
+  native::xpu::conj_physical_kernel(iter);
+  return out;
+}
+
+Tensor& XPUNativeFunctions::conj_physical_(Tensor& self) {
+  if (!self.is_complex())
+    return self;
+  return XPUNativeFunctions::conj_physical_out(self, self);
+}
+
 TensorIterator ceil_meta(const Tensor& self, Tensor& out) {
   TORCH_CHECK(!self.is_complex(), "ceil is not supported for complex inputs");
   TensorIterator iter;
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 496eb00f1..0ece10206 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -184,7 +184,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "cholesky",
       "cholesky_inverse",
       "_cholesky_solve_helper",
-      "conj_physical.out",
       "copysign.out",
       "cosh.out",
       "count_nonzero.dim_IntList",
diff --git a/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp b/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
index e082096c1..87de57a3a 100644
--- a/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
+++ b/src/ATen/native/xpu/sycl/UnaryComplexKernels.cpp
@@ -30,6 +30,32 @@ void conj_kernel(TensorIterator& iter) {
   }));
 }
 
+template <typename scalar_t>
+struct ConjPhysicalFunctor {
+  scalar_t operator()(scalar_t z) const {
+    return std::conj(z);
+  }
+};
+
+template <typename TYPE>
+struct ConjPhysicalFunctor<c10::complex<TYPE>> {
+  c10::complex<TYPE> operator()(c10::complex<TYPE> z) const {
+    return c10::complex<TYPE>(z.real(), -z.imag());
+  }
+};
+
+void conj_physical_kernel(TensorIterator& iter) {
+  AT_DISPATCH_SWITCH(
+      iter.common_dtype(),
+      "conj_xpu",
+      AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
+        // Conj is a no-op for non-complex types
+        copy_kernel(iter);
+      }) AT_DISPATCH_CASE_COMPLEX_TYPES_AND(kComplexHalf, [&] {
+        gpu_kernel(iter, ConjPhysicalFunctor<scalar_t>());
+      }));
+}
+
 template <typename scalar_t>
 struct NegConjScalarFunc {
   scalar_t operator()(scalar_t src_val) const {
diff --git a/src/ATen/native/xpu/sycl/UnaryComplexKernels.h b/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
index 8d19381b3..d3ad4fe15 100644
--- a/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
+++ b/src/ATen/native/xpu/sycl/UnaryComplexKernels.h
@@ -6,6 +6,8 @@ namespace at::native::xpu {
 
 void conj_kernel(TensorIterator& iter);
 
+void conj_physical_kernel(TensorIterator& iter);
+
 void neg_conj_kernel(TensorIterator& iter);
 
 void neg_kernel(TensorIterator& iter);
diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 9287a85b6..02e2542c8 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -207,7 +207,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_python_ref_torch_fallback__refs_square_xpu_bool",
     "test_python_ref_torch_fallback__refs_vdot_xpu_complex128",
     "test_python_ref_torch_fallback__refs_vdot_xpu_complex64",
-    "test_variant_consistency_eager_conj_physical_xpu_complex64",
     "test_variant_consistency_eager_nn_functional_conv_transpose2d_xpu_complex64",
     "test_variant_consistency_eager_nn_functional_conv_transpose2d_xpu_float32",
     "test_variant_consistency_eager_nn_functional_conv_transpose3d_xpu_complex64",
@@ -242,8 +241,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_python_ref_executor__refs_square_executor_aten_xpu_complex128",
     "test_python_ref_torch_fallback__refs_square_xpu_complex128",
     "test_python_ref_torch_fallback__refs_square_xpu_complex64",
-    "test_conj_view_conj_physical_xpu_complex64",
-    "test_neg_conj_view_conj_physical_xpu_complex128",
 
     # Skip list of new added when porting XPU operators.
     # See: https://github.com/intel/torch-xpu-ops/issues/128
@@ -2207,9 +2204,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # torch.autograd.gradcheck.GradcheckError: Jacobian computed with forward mode mismatch for output 0 with respect to input 0,
     "test_fn_fwgrad_bwgrad_nn_functional_rrelu_xpu_float64",
     "test_forward_mode_AD_nn_functional_rrelu_xpu_float64",
-    # RuntimeError: DispatchStub: unsupported device typexpu
-    "test_inplace_forward_mode_AD_conj_physical_xpu_complex128",
-    # NotImplementedError: Could not run 'aten::_to_dense' with arguments from the 'SparseXPU' backend.
+# NotImplementedError: Could not run 'aten::_to_dense' with arguments from the 'SparseXPU' backend.
     "test_fn_fwgrad_bwgrad_to_sparse_xpu_float64",
     "test_forward_mode_AD_to_sparse_xpu_float64",
 )
@@ -2745,9 +2740,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     ### Error #7 in TestBwdGradientsXPU , totally 2 , NotImplementedError: Could not run 'aten::_sparse_coo_tensor_with_dims_and_tensors' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_sparse_coo_tensor_with_dims_and_tensors' is only available for these backends: [XPU, Meta, SparseCPU, SparseMeta, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].
     "test_fn_grad_to_sparse_xpu_float64",
     "test_fn_gradgrad_to_sparse_xpu_float64",
-    ### Error #8 in TestBwdGradientsXPU , totally 2 , RuntimeError: DispatchStub: unsupported device typexpu
-    "test_inplace_grad_conj_physical_xpu_complex128",
-    "test_inplace_gradgrad_conj_physical_xpu_complex128",
 )
 
 res += launch_test("test_ops_gradients_xpu.py", skip_list)
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 6511f4120..add5367fa 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -155,6 +155,7 @@
     "bincount",
     "renorm",
     "lerp",
+    "conj_physical",
 ]
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2cd535394..10fd6748b 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -502,6 +502,8 @@ supported:
   - randperm.generator_out
   - _amp_foreach_non_finite_check_and_unscale_
   - _amp_update_scale_
+  - conj_physical.out
+  - conj_physical_
   - ceil
   - ceil_
   - ceil.out
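For context on what the new registration enables: conj_physical materializes the conjugate (negated imaginary part) rather than setting the conjugate bit, and reduces to a plain copy for real dtypes, which matches the split in conj_physical_kernel between the ALL_TYPES and COMPLEX_TYPES dispatch cases. The sketch below is illustrative only and not part of the patch; the device-availability guard and sample values are assumptions.

import torch

# Illustrative usage sketch (assumes a PyTorch build with XPU support).
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

z = torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64, device=device)
out = torch.conj_physical(z)   # materialized conjugate: [1.-2.j, 3.+4.j]
z.conj_physical_()             # in-place variant added by this patch; z now equals out

r = torch.ones(3, device=device)
torch.conj_physical(r)         # real dtypes: the kernel reduces to a copy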