11
11
#include < c10/util/irange.h>
12
12
#include < torch/custom_class.h>
13
13
#include < torch/library.h>
14
+ #include < ATen/Config.h>
14
15
15
16
#ifndef AT_PER_OPERATOR_HEADERS
16
17
#include < ATen/Functions.h>
50
51
#include < ATen/ops/tanh_backward.h>
51
52
#include < ATen/ops/zeros_like.h>
52
53
#include < ATen/ops/zeros_like_ops.h>
53
-
54
54
#include < utility>
55
55
#endif
56
56
@@ -69,6 +69,17 @@ bool use_miopen(const at::Tensor& input, const double dropout_state) {
69
69
return is_miopen_acceptable;
70
70
}
71
71
72
+ bool use_mkldnn (const Tensor& input) {
73
+ #if AT_MKLDNN_ENABLED()
74
+ if (!at::globalContext ().userEnabledMkldnn ()) {
75
+ return false ;
76
+ }
77
+ return input.options ().backend () == at::Backend::CPU &&
78
+ (input.scalar_type () == kFloat || input.scalar_type () == kBFloat16 );
79
+ #endif
80
+ return false ;
81
+ }
82
+
72
83
template <typename T>
73
84
using pair_of = std::pair<T, T>;
74
85
@@ -1409,6 +1420,7 @@ DEFINE_DISPATCH(lstm_cudnn_stub);
1409
1420
DEFINE_DISPATCH (lstm_packed_cudnn_stub);
1410
1421
DEFINE_DISPATCH (lstm_miopen_stub);
1411
1422
DEFINE_DISPATCH (lstm_packed_miopen_stub);
1423
+ DEFINE_DISPATCH (lstm_mkldnn_stub);
1412
1424
REGISTER_NO_CPU_DISPATCH (lstm_cudnn_stub);
1413
1425
REGISTER_NO_CPU_DISPATCH (lstm_packed_cudnn_stub);
1414
1426
REGISTER_NO_CPU_DISPATCH (lstm_miopen_stub);
@@ -1447,6 +1459,23 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
1447
1459
}
1448
1460
}
1449
1461
1462
+ if (use_mkldnn (_input)) {
1463
+ if (!has_projections) {
1464
+ if (hx[0 ].unsafeGetTensorImpl ()->has_symbolic_sizes_strides ()) {
1465
+ TORCH_WARN_ONCE (
1466
+ " LSTM with symbolic sizes and strides is not supported with oneDNN. Using default implementation." );
1467
+ } else {
1468
+ Tensor output, hy, cy;
1469
+ lstm_mkldnn_stub (_input.device ().type (), output, hy, cy,_input, hx, _params, has_biases,
1470
+ num_layers, dropout_p, train, bidirectional, batch_first);
1471
+ return std::make_tuple (std::move (output), std::move (hy), std::move (cy));
1472
+ }
1473
+ } else {
1474
+ TORCH_WARN_ONCE (
1475
+ " LSTM with projections is not supported with oneDNN. Using default implementation." );
1476
+ }
1477
+ }
1478
+
1450
1479
check_attributes (_input, _params, hx);
1451
1480
auto input = batch_first ? _input.transpose (0 , 1 ) : _input;
1452
1481
auto params = gather_params (_params, has_biases, has_projections);
0 commit comments