Skip to content

Commit

Permalink
Add aten::_thnn_fused_gru_cell and _thnn_fused_lstm_cell (#926)
Browse files Browse the repository at this point in the history
- [x] thnn_fused_gru_cell_forward
- [x] thnn_fused_gru_cell_backward
- [x] thnn_fused_lstm_cell_forward
- [x] thnn_fused_lstm_cell_backward

---------

Co-authored-by: Yutao Xu <[email protected]>
  • Loading branch information
2 people authored and ZhiweiYan-96 committed Dec 30, 2024
1 parent 214f33b commit 23575f5
Show file tree
Hide file tree
Showing 7 changed files with 1,081 additions and 30 deletions.
46 changes: 46 additions & 0 deletions src/ATen/native/xpu/RNN.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/native/xpu/sycl/RNNKernels.h>

namespace at::native {

std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_xpu(
const Tensor& input_gates,
const Tensor& hidden_gates,
const Tensor& cx,
const std::optional<Tensor>& input_bias_opt,
const std::optional<Tensor>& hidden_bias_opt) {
return native::xpu::_thnn_fused_lstm_cell_kernel(
input_gates, hidden_gates, cx, input_bias_opt, hidden_bias_opt);
}

std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_xpu(
const std::optional<Tensor>& grad_hy_opt,
const std::optional<Tensor>& grad_cy_opt,
const Tensor& cx,
const Tensor& cy,
const Tensor& workspace,
bool has_bias) {
return native::xpu::_thnn_fused_lstm_cell_backward_kernel(
grad_hy_opt, grad_cy_opt, cx, cy, workspace, has_bias);
}

std::tuple<at::Tensor, at::Tensor> _thnn_fused_gru_cell_xpu(
const Tensor& input_gates,
const Tensor& hidden_gates,
const Tensor& hx,
const std::optional<at::Tensor>& input_bias,
const std::optional<at::Tensor>& hidden_bias) {
return native::xpu::_thnn_fused_gru_cell_kernel(
input_gates, hidden_gates, hx, input_bias, hidden_bias);
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
_thnn_fused_gru_cell_backward_xpu(
const Tensor& grad_hy,
const Tensor& workspace,
bool has_bias) {
return native::xpu::_thnn_fused_gru_cell_backward_kernel(
grad_hy, workspace, has_bias);
}

} // namespace at::native
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"ormqr",
"_scaled_dot_product_efficient_attention",
"_scaled_mm",
"_thnn_fused_gru_cell",
"_to_sparse_csr",
"triangular_solve.X",
"_upsample_bilinear2d_aa.out",
Expand Down
Loading

0 comments on commit 23575f5

Please sign in to comment.