Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into penghuic/LNL_BMG_skip
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng committed Dec 20, 2024
2 parents c960e9a + 9ed0a1a commit ea6084b
Show file tree
Hide file tree
Showing 9 changed files with 1,083 additions and 36 deletions.
46 changes: 46 additions & 0 deletions src/ATen/native/xpu/RNN.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include <ATen/ATen.h>
#include <ATen/native/xpu/sycl/RNNKernels.h>

namespace at::native {

std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_xpu(
const Tensor& input_gates,
const Tensor& hidden_gates,
const Tensor& cx,
const std::optional<Tensor>& input_bias_opt,
const std::optional<Tensor>& hidden_bias_opt) {
return native::xpu::_thnn_fused_lstm_cell_kernel(
input_gates, hidden_gates, cx, input_bias_opt, hidden_bias_opt);
}

std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_xpu(
const std::optional<Tensor>& grad_hy_opt,
const std::optional<Tensor>& grad_cy_opt,
const Tensor& cx,
const Tensor& cy,
const Tensor& workspace,
bool has_bias) {
return native::xpu::_thnn_fused_lstm_cell_backward_kernel(
grad_hy_opt, grad_cy_opt, cx, cy, workspace, has_bias);
}

std::tuple<at::Tensor, at::Tensor> _thnn_fused_gru_cell_xpu(
const Tensor& input_gates,
const Tensor& hidden_gates,
const Tensor& hx,
const std::optional<at::Tensor>& input_bias,
const std::optional<at::Tensor>& hidden_bias) {
return native::xpu::_thnn_fused_gru_cell_kernel(
input_gates, hidden_gates, hx, input_bias, hidden_bias);
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
_thnn_fused_gru_cell_backward_xpu(
const Tensor& grad_hy,
const Tensor& workspace,
bool has_bias) {
return native::xpu::_thnn_fused_gru_cell_backward_kernel(
grad_hy, workspace, has_bias);
}

} // namespace at::native
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"lu_unpack.out",
"ormqr",
"_scaled_mm",
"_thnn_fused_gru_cell",
"_to_sparse_csr",
"triangular_solve.X",
"_validate_compressed_sparse_indices",
Expand Down
Loading

0 comments on commit ea6084b

Please sign in to comment.