Skip to content

Commit

Permalink
Add aten::eye and its variant (#480)
Browse files Browse the repository at this point in the history
Signed-off-by: Feng Yuan <[email protected]>
  • Loading branch information
fengyuan14 authored Jul 5, 2024
1 parent 6781c4a commit 1951fce
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
19 changes: 19 additions & 0 deletions src/ATen/native/xpu/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@

namespace at {

Tensor& XPUNativeFunctions::eye_out(int64_t n, Tensor& result) {
return XPUNativeFunctions::eye_out(n, n, result);
}

Tensor& XPUNativeFunctions::eye_out(int64_t n, int64_t m, Tensor& result) {
TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

result.resize_({n, m});
result.zero_();

int64_t sz = std::min<int64_t>(n, m);
int64_t stride = result.stride(0) + result.stride(1);

Tensor diag = result.as_strided({sz}, {stride});
diag.fill_(1);
return result;
}

Tensor XPUNativeFunctions::empty(
IntArrayRef size,
c10::optional<ScalarType> dtype_opt,
Expand Down
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"exp2.out",
"expm1.out",
"exponential_",
"eye.m_out",
"_fft_c2c",
"_fft_c2r",
"_fft_r2c",
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

_xpu_computation_op_list = [
"empty",
"eye",
"fill",
"zeros",
"zeros_like",
Expand Down
2 changes: 2 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ supported:
- exp_
- empty.memory_format
- empty_strided
- eye.out
- eye.m_out
- _efficientzerotensor
- complex.out
- clone
Expand Down

0 comments on commit 1951fce

Please sign in to comment.