Skip to content

Commit

Permalink
Add aten::trace (#446)
Browse files Browse the repository at this point in the history
Add aten::trace

Co-authored-by: Feng Yuan <[email protected]>
  • Loading branch information
xytintel and fengyuan14 authored Jul 4, 2024
1 parent 662ab6a commit 41ac6a3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/TriangluarOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,10 @@ Tensor& XPUNativeFunctions::triu_(Tensor& self, int64_t diagonal) {
xpu::check_inplace(self, self.sizes(), self.options());
return triu_out(self, diagonal, self);
}

Tensor XPUNativeFunctions::trace(const Tensor& self) {
TORCH_CHECK(self.dim() == 2, "expected a matrix");
return self.diagonal().sum();
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"topk.values",
"_to_sparse",
"_to_sparse_csr",
"trace",
"triangular_solve.X",
"tril_indices",
"triu_indices",
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"view_as_real",
"view_as_complex",
"view",
"trace",
"resize_",
"resize_as_",
"add",
Expand Down
1 change: 1 addition & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ supported:
- _cdist_forward
- _pin_memory
- is_pinned
- trace
- reflection_pad2d
- reflection_pad2d.out
- reflection_pad2d_backward
Expand Down

0 comments on commit 41ac6a3

Please sign in to comment.