Skip to content

Commit 94a7c01

Browse files
yanbing-j authored and pytorchmergebot committed
Enable oneDNN implementation in LSTM op (pytorch#91158)
### Description

This PR enables the oneDNN implementation in the LSTM op to improve its performance. Both FP32 and BF16 are supported.

### Performance improvement

On CPX (28 cores), with iomp and jemalloc enabled, we chose 8 LSTM input configurations (covering input_size, hidden_size, num_layers, bidirectional, bias, batch_first, dropout, batch_size, and seq_len); the final configuration is a real input from train-clean-100 in the LibriSpeech dataset. The performance improvements are shown in the following figures: the LSTM with the oneDNN implementation outperforms the original one.

In a single socket:
![image](https://user-images.githubusercontent.com/61222868/211182994-833debec-518a-4b35-8504-6b0fadb17930.png)
![image](https://user-images.githubusercontent.com/61222868/211183012-31e1253f-2c60-4c92-a656-c239a971b453.png)

In a single core:
![image](https://user-images.githubusercontent.com/61222868/211183017-186e5d47-cb9a-4c1e-914f-fa718e769f1c.png)
![image](https://user-images.githubusercontent.com/61222868/211183022-53266857-5a9e-4a95-b300-33fa34811d08.png)

Pull Request resolved: pytorch#91158
Approved by: https://github.com/jgong5, https://github.com/malfet
1 parent a41f00e commit 94a7c01

11 files changed

+804
-1
lines changed

aten/src/ATen/autocast_mode.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
506506
KERNEL_CPU2(_convolution, deprecated, lower_precision_fp)
507507
KERNEL_CPU(matmul, lower_precision_fp)
508508
KERNEL_CPU(conv_tbc, lower_precision_fp)
509+
KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
509510

510511
// fp32 cast policy
511512
KERNEL_CPU(conv_transpose1d, fp32)

aten/src/ATen/native/RNN.cpp

+30-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <c10/util/irange.h>
1212
#include <torch/custom_class.h>
1313
#include <torch/library.h>
14+
#include <ATen/Config.h>
1415

1516
#ifndef AT_PER_OPERATOR_HEADERS
1617
#include <ATen/Functions.h>
@@ -50,7 +51,6 @@
5051
#include <ATen/ops/tanh_backward.h>
5152
#include <ATen/ops/zeros_like.h>
5253
#include <ATen/ops/zeros_like_ops.h>
53-
5454
#include <utility>
5555
#endif
5656

@@ -69,6 +69,17 @@ bool use_miopen(const at::Tensor& input, const double dropout_state) {
6969
return is_miopen_acceptable;
7070
}
7171

72+
// Decide whether the oneDNN (MKL-DNN) LSTM path may be used for `input`.
// Requires a oneDNN-enabled build, the user toggle (torch.backends.mkldnn)
// left on, a CPU-backend tensor, and a float32 or bfloat16 dtype.
bool use_mkldnn(const Tensor& input) {
#if AT_MKLDNN_ENABLED()
  if (at::globalContext().userEnabledMkldnn()) {
    const bool on_cpu = input.options().backend() == at::Backend::CPU;
    const auto dtype = input.scalar_type();
    // Only FP32 and BF16 LSTM are supported by the oneDNN kernel.
    return on_cpu && (dtype == kFloat || dtype == kBFloat16);
  }
#endif
  return false;
}
82+
7283
template<typename T>
7384
using pair_of = std::pair<T, T>;
7485

@@ -1409,6 +1420,7 @@ DEFINE_DISPATCH(lstm_cudnn_stub);
14091420
DEFINE_DISPATCH(lstm_packed_cudnn_stub);
14101421
DEFINE_DISPATCH(lstm_miopen_stub);
14111422
DEFINE_DISPATCH(lstm_packed_miopen_stub);
1423+
DEFINE_DISPATCH(lstm_mkldnn_stub);
14121424
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub);
14131425
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub);
14141426
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub);
@@ -1447,6 +1459,23 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
14471459
}
14481460
}
14491461

1462+
if (use_mkldnn(_input)) {
1463+
if (!has_projections) {
1464+
if (hx[0].unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
1465+
TORCH_WARN_ONCE(
1466+
"LSTM with symbolic sizes and strides is not supported with oneDNN. Using default implementation.");
1467+
} else {
1468+
Tensor output, hy, cy;
1469+
lstm_mkldnn_stub(_input.device().type(), output, hy, cy,_input, hx, _params, has_biases,
1470+
num_layers, dropout_p, train, bidirectional, batch_first);
1471+
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
1472+
}
1473+
} else {
1474+
TORCH_WARN_ONCE(
1475+
"LSTM with projections is not supported with oneDNN. Using default implementation.");
1476+
}
1477+
}
1478+
14501479
check_attributes(_input, _params, hx);
14511480
auto input = batch_first ? _input.transpose(0, 1) : _input;
14521481
auto params = gather_params(_params, has_biases, has_projections);

aten/src/ATen/native/RNN.h

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, co
1212

1313
DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
1414
DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
15+
DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
1516
DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
1617
DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
1718
DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);

aten/src/ATen/native/mkldnn/MKLDNNCommon.h

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ namespace at { namespace native {
1010

1111
// Mapping ScalarType to ideep tensor data_type
1212
TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type);
13+
// Convenience overload: map a Tensor's scalar type directly to the
// corresponding ideep tensor data type.
static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& tensor) {
  return get_mkldnn_dtype(tensor.scalar_type());
}
1316

1417
// Construct aten MKL-DNN tensor given an ideep tensor
1518
TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional<ScalarType> dtype, c10::optional<Device> device);

0 commit comments

Comments
 (0)