Skip to content

Commit

Permalink
Add aten::_masked_softmax and its variants (#930)
Browse files Browse the repository at this point in the history
- [x] _masked_softmax
- [x] _masked_softmax_backward
  • Loading branch information
xytintel authored Oct 29, 2024
1 parent f69c52f commit 2339e13
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 73 deletions.
16 changes: 16 additions & 0 deletions src/ATen/native/xpu/SoftMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,20 @@ TORCH_IMPL_FUNC(log_softmax_xpu_out)
xpu::_log_softmax_kernel(input, dim, half_to_float, output);
}

Tensor masked_softmax_xpu(
const Tensor& input_,
const Tensor& mask_,
const std::optional<int64_t> dim_,
const std::optional<int64_t> mask_type_) {
return xpu::masked_softmax_kernel(input_, mask_, dim_, mask_type_);
}

Tensor masked_softmax_backward_xpu(
const Tensor& grad_,
const Tensor& output_,
const Tensor& mask_,
const std::optional<int64_t> dim_) {
return xpu::masked_softmax_backward_kernel(grad_, output_, mask_, dim_);
}

} // namespace at::native
Loading

0 comments on commit 2339e13

Please sign in to comment.