From 1951fce0fc58eae0c3e1d491c85958b73c61575f Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Fri, 5 Jul 2024 14:55:21 +0800 Subject: [PATCH] Add aten::eye and its variant (#480) Signed-off-by: Feng Yuan --- src/ATen/native/xpu/TensorFactories.cpp | 19 +++++++++++++++++++ src/ATen/native/xpu/XPUFallback.template | 1 - test/xpu/xpu_test_utils.py | 1 + yaml/xpu_functions.yaml | 2 ++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/TensorFactories.cpp b/src/ATen/native/xpu/TensorFactories.cpp index ee29aa167..110590958 100644 --- a/src/ATen/native/xpu/TensorFactories.cpp +++ b/src/ATen/native/xpu/TensorFactories.cpp @@ -18,6 +18,25 @@ namespace at { +Tensor& XPUNativeFunctions::eye_out(int64_t n, Tensor& result) { + return XPUNativeFunctions::eye_out(n, n, result); +} + +Tensor& XPUNativeFunctions::eye_out(int64_t n, int64_t m, Tensor& result) { + TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); + TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); + + result.resize_({n, m}); + result.zero_(); + + int64_t sz = std::min(n, m); + int64_t stride = result.stride(0) + result.stride(1); + + Tensor diag = result.as_strided({sz}, {stride}); + diag.fill_(1); + return result; +} + Tensor XPUNativeFunctions::empty( IntArrayRef size, c10::optional dtype_opt, diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 66f7dd905..496eb00f1 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -203,7 +203,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "exp2.out", "expm1.out", "exponential_", - "eye.m_out", "_fft_c2c", "_fft_c2r", "_fft_r2c", diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 6f296cbd0..6511f4120 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -19,6 +19,7 @@ _xpu_computation_op_list = [ "empty", + "eye", "fill", "zeros", "zeros_like", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 04c027ca6..2cd535394 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -190,6 +190,8 @@ supported: - exp_ - empty.memory_format - empty_strided + - eye.out + - eye.m_out - _efficientzerotensor - complex.out - clone