diff --git a/src/include/miopen/rnn.hpp b/src/include/miopen/rnn.hpp
index 00aadbabdf..c5a29299c1 100644
--- a/src/include/miopen/rnn.hpp
+++ b/src/include/miopen/rnn.hpp
@@ -261,6 +261,21 @@ struct RNNDescriptor : miopenRNNDescriptor
                             Data_t reserveSpace,
                             size_t reserveSpaceSize) const;
 
+    void RNNForwardTrainingTanhRelu(Handle& handle,
+                                    std::vector<int>& seq_array,
+                                    const TensorDescriptor& xDesc,
+                                    ConstData_t x,
+                                    const TensorDescriptor& hxDesc,
+                                    ConstData_t hx,
+                                    const TensorDescriptor& wDesc,
+                                    ConstData_t w,
+                                    const TensorDescriptor& yDesc,
+                                    Data_t y,
+                                    const TensorDescriptor& hyDesc,
+                                    Data_t hy,
+                                    Data_t reserveSpace,
+                                    size_t reserveSpaceSize) const;
+
     void RNNForwardInference(Handle& handle,
                              int seqLen,
                              c_array_view<const miopenTensorDescriptor_t> xDesc,
@@ -462,6 +477,21 @@ struct RNNDescriptor : miopenRNNDescriptor
                             ConstData_t reserveSpace,
                             size_t reserveSpaceSize) const;
 
+    void RNNBackwardDataPackedTensorsRelu(Handle& handle,
+                                          int seqLen,
+                                          c_array_view<const miopenTensorDescriptor_t> dyDesc,
+                                          ConstData_t dy,
+                                          ConstData_t dhy,
+                                          ConstData_t w,
+                                          c_array_view<const miopenTensorDescriptor_t> dxDesc,
+                                          Data_t dx,
+                                          const TensorDescriptor& dhxDesc,
+                                          Data_t dhx,
+                                          Data_t workSpace,
+                                          size_t workSpaceSize,
+                                          Data_t reserveSpace,
+                                          size_t reserveSpaceSize) const;
+
     void RNNForwardTrainingPackedTensors(Handle& handle,
                                          int seqLen,
                                          c_array_view<const miopenTensorDescriptor_t> xDesc,
diff --git a/src/include/miopen/rnn_util.hpp b/src/include/miopen/rnn_util.hpp
index 84e453358b..5b65de55f1 100644
--- a/src/include/miopen/rnn_util.hpp
+++ b/src/include/miopen/rnn_util.hpp
@@ -32,9 +32,58 @@
 #include
 #include
 #include
+#include
 
 namespace miopen {
 
+enum class RnnDirection
+{
+    Forward  = 0,
+    Backward = 1
+};
+
+struct RnnBatches
+{
+    int at(int time, RnnDirection direction) const { return batches.at(cur_time(time, direction)); }
+
+    int next(int time, RnnDirection direction) const
+    {
+        return batches.at(next_time(time, direction));
+    }
+
+    int prev(int time, RnnDirection direction) const
+    {
+        return batches.at(prev_time(time, direction));
+    }
+
+    void push_back(int batch) { batches.push_back(batch); }
+
+    RnnBatches(std::vector<int>& input) : batches(input){};
+    RnnBatches(){};
+
+    int back() const { return batches.back(); }
+
+private:
+    int cur_time(int time, RnnDirection direction) const
+    {
+        return direction == RnnDirection::Forward ? time : batches.size() - time - 1;
+    }
+
+    int next_time(int time, RnnDirection direction) const
+    {
+        return direction == RnnDirection::Forward ? cur_time(time, direction) + 1
+                                                  : cur_time(time, direction) - 1;
+    }
+
+    int prev_time(int time, RnnDirection direction) const
+    {
+        return direction == RnnDirection::Forward ? cur_time(time, direction) - 1
+                                                  : cur_time(time, direction) + 1;
+    }
+
+    std::vector<int> batches;
+};
+
 #if MIOPEN_BACKEND_HIP
 inline void RNNProfilingBegin(const miopen::Handle& handle,
                               miopen::HipEventPtr& start,
@@ -121,6 +170,129 @@ void LSTMBackwardHiddenStateUpdate(const Handle& handle,
                                    std::size_t dhidden_offset,
                                    std::size_t f_offset_pre);
 
+struct ReluWeightOffsets
+{
+public:
+    ReluWeightOffsets(int input_vector_sz,
+                      int hidden_vec_sz,
+                      int layers_cnt,
+                      int bias_mode,
+                      int bi,
+                      int nHiddenTensorsPerLayer)
+        : weight_stride(hidden_vec_sz * bi * nHiddenTensorsPerLayer),
+          in_vec_sz(input_vector_sz),
+          h_vec_sz(hidden_vec_sz),
+          num_layers(layers_cnt),
+          bi_scale(bi),
+          bias_count(bias_mode)
+    {
+    }
+
+    int input_weight_offset(int layer) const
+    {
+        return layer == 0 ? 0
+                          : first_layer_offset() +
+                                (h_vec_sz + h_vec_sz * bi_scale) * weight_stride * (layer - 1);
+    }
+
+    int hidden_weight_offset(int layer, RnnDirection reverse) const
+    {
+        return layer == 0 ?
input_weight_offset(layer) + in_vec_sz * weight_stride + + static_cast(reverse) * h_vec_sz * h_vec_sz + : input_weight_offset(layer) + bi_scale * h_vec_sz * weight_stride + + static_cast(reverse) * h_vec_sz * h_vec_sz; + } + + size_t bias_stride() const { return static_cast(h_vec_sz) * bi_scale; } + + int bias_off() const + { + return first_layer_offset() + + (h_vec_sz * bi_scale + h_vec_sz) * (num_layers - 1) * weight_stride; + } + + int + bias_off(int layer_id, int bias_id, RnnDirection direction) const + { + return bias_off() + bias_count * layer_id * weight_stride + bias_id * bias_stride() + + static_cast(direction) * h_vec_sz; + } + int weight_stride; + +private: + const int in_vec_sz, h_vec_sz; + +public: + const int num_layers; + const int bi_scale = 1; + const int bias_count = 0; + + int first_layer_offset() const { return (in_vec_sz + h_vec_sz) * weight_stride; } +}; + +struct ReluReserveBufferOffsets +{ + struct RBuffHelper + { + int element, save_point, batch; + size_t layer, table; + }; + +private: + auto Reserve_Buffer_strides(int save_point_sz, int batches_per_l, int layers_cnt) const + { + const auto element_st = 1; + const auto save_point_st = element_st * save_point_sz; + const auto batch_st = save_point_st; + const auto layer_st = static_cast(batch_st) * batches_per_l; + const auto table_st = layers_cnt * layer_st; + + return RBuffHelper{element_st, save_point_st, batch_st, layer_st, table_st}; + } + +public: + ReluReserveBufferOffsets( + int hidden_vec_size, int layers_cnt, int batches_per_l, int bi_scale, int workspace_scale) + : hidden_size(hidden_vec_size), + batches_per_layer(batches_per_l), + save_point_size(hidden_vec_size * bi_scale * workspace_scale), + layers(layers_cnt), + strides(Reserve_Buffer_strides(save_point_size, batches_per_l, layers_cnt)) + { + } + + size_t layer_offset(int layer_id) const + { + return static_cast(layer_id) * strides.layer; + } + + size_t layer_stride() const { return strides.layer; } + + int gemm_write_size() const { return strides.save_point; } + + size_t gemm_write_stride() const { return strides.batch; } + + size_t gemm_write_offset(int layer_id, int batch_id, RnnDirection reverse) const + { + return layer_offset(layer_id) + static_cast(gemm_write_stride()) * batch_id + + static_cast(reverse) * hidden_size; + } + + size_t hidden_offset(int layer_id, int batch_id, RnnDirection reverse) const + { + return strides.table + gemm_write_offset(layer_id, batch_id, reverse); + } + +private: + const int hidden_size; + +public: + const int batches_per_layer; + const int save_point_size; + const int layers; + const RBuffHelper strides; +}; + struct RNNTensorPaddingConverter { static void ConvertTensorData(const Handle& handle, diff --git a/src/ocl/rnnocl.cpp b/src/ocl/rnnocl.cpp index 131d69db46..89695382b1 100644 --- a/src/ocl/rnnocl.cpp +++ b/src/ocl/rnnocl.cpp @@ -40,410 +40,402 @@ MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_RNNFWD_exp) namespace miopen { -void RNNDescriptor::RNNForwardTraining_MS(Handle& handle, - std::vector& seq_array, - const TensorDescriptor& xDesc, - ConstData_t x, - const TensorDescriptor& hxDesc, - ConstData_t hx, - ConstData_t cx, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& yDesc, - Data_t y, - Data_t hy, - Data_t cy, - Data_t reserveSpace, - size_t reserveSpaceSize) const +void RNNDescriptor::RNNForwardTrainingTanhRelu(Handle& handle, + std::vector& seq_array, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& 
wDesc, + ConstData_t w, + const TensorDescriptor& yDesc, + Data_t y, + const TensorDescriptor& hyDesc, + Data_t hy, + Data_t reserveSpace, + size_t reserveSpaceSize) const { #if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP - std::vector in_n; - int in_vec = xDesc.GetLengths()[1]; // input vector size - int out_vec = yDesc.GetLengths()[1]; // output vector size + int seq_len = seq_array.size(); + if(seq_len == 0) + return; + + RnnBatches batches(seq_array); + RnnBatches bacc_per_time; + auto rnn_data_type = wDesc.GetType(); + + float beta = 0; + + int in_vec_size = xDesc.GetLengths()[1]; + int out_vec_size = yDesc.GetLengths()[1]; + int biNumLayers = hyDesc.GetLengths()[0]; - int seq_len = seq_array.size(); int max_batch = seq_array[0]; int hidden_size; - std::tie(std::ignore, max_batch, hidden_size) = miopen::tien<3>(hxDesc.GetLengths()); - auto extra_stream_cnt = 2; - handle.ReserveExtraStreamsInPool(extra_stream_cnt); + int total_batch_size = 0; - auto root_stream_id = 0; - std::vector stream_pull; - for(int i = 0; i <= extra_stream_cnt; i++) + for(int i = 0; i < seq_len; i++) { - handle.SetStreamFromPool(i); - stream_pull.push_back(handle.GetStream()); + bacc_per_time.push_back(total_batch_size); + total_batch_size += seq_array[i]; } - handle.SetStreamFromPool(root_stream_id); - - int total_batch_size = 0; - std::vector bacc_per_time(seq_len + 1); + int bi = dirMode != 0u ? 2 : 1; - for(int i = 0; i < seq_len; i++) + if(in_vec_size <= 0 || hidden_size <= 0 || max_batch <= 0 || biNumLayers <= 0 || + out_vec_size <= 0 || seq_len == 0) { - bacc_per_time[i] = total_batch_size; - total_batch_size += seq_array[i]; - in_n.push_back(seq_array[i]); + MIOPEN_THROW(miopenStatusBadParm); } - bacc_per_time[seq_len] = total_batch_size; - const struct + const auto sp_tensor_size = reserveSpaceSize / GetTypeSize(rnn_data_type); + + auto sp_desc = miopen::TensorDescriptor( + rnn_data_type, {1, 1, sp_tensor_size}, {sp_tensor_size, sp_tensor_size, 1}); + + // Clear reserveSpace buffer + // + SetTensor(handle, sp_desc, reserveSpace, &beta); + + if(hy != nullptr) { - int batch; - } InBuff_strides{in_vec}; + const auto hy_tensor_size = biNumLayers * max_batch * hidden_size; + auto hy_desc = miopen::TensorDescriptor( + rnn_data_type, {1, 1, hy_tensor_size}, {hy_tensor_size, hy_tensor_size, 1}); + // Clear hy buffer + // + SetTensor(handle, hy_desc, hy, &beta); + } - auto get_HxBuff_offset = [&](int layer_id) { - return layer_id * (static_cast(hidden_size) * max_batch); - }; + auto get_HxBuff_offset = + [bi, hidden_size, max_batch](int layer_id, int batch_id, RnnDirection reverse) { + return (static_cast(hidden_size) * (max_batch)) * + (bi * layer_id + static_cast(reverse)) + + (size_t)hidden_size * batch_id; + }; - int gates_cnt = 4; - int save_points_cnt = 6; + ReluWeightOffsets WeiBuf( + in_vec_size, hidden_size, nLayers, biasMode * 2, bi, nHiddenTensorsPerLayer); - struct WeightsBufferHelper + ReluReserveBufferOffsets RBuff(hidden_size, nLayers, total_batch_size, bi, workspaceScale); + + ActivationDescriptor activDesc; + + if(rnnMode == miopenRNNRELU) { - private: - auto hidden_xinput_size(int hidden_sz, int bidirect_mode) const - { - if(bidirect_mode == 0) - return hidden_sz; - MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); - } + activDesc = {miopenActivationRELU, 1, 0, 1}; + } + else if(rnnMode == miopenRNNTANH) + { + activDesc = {miopenActivationTANH, 1, 1, 1}; + } - auto matrix_lin_layer_size(int input_vector_sz, int hidden_vec_sz, int gates) const - { - return (input_vector_sz + 
hidden_vec_sz) * hidden_vec_sz * gates; - } - size_t bias_start_offset(int input_vector_sz, - int hidden_vec_sz, - int layers_cnt, - int gates, - int bidirect_mode) const + auto call_relu_tan_input_gemm = [this, + &RBuff, + &WeiBuf, + &in_vec_size, + &handle, + &xDesc, + reserveSpace, + x, + w, + hidden_size, + rnn_data_type, + bi](int layer) { + if(inputMode == miopenRNNskip && layer == 0) { - if(bidirect_mode == 0) + auto x_desc = + miopen::TensorDescriptor(rnn_data_type, + {1, RBuff.batches_per_layer, hidden_size}, + {1, static_cast(RBuff.batches_per_layer ) * in_vec_size, in_vec_size}); + auto ht_desc = + miopen::TensorDescriptor(rnn_data_type, + {1, RBuff.batches_per_layer, hidden_size}, + {RBuff.layer_stride(), RBuff.gemm_write_stride(), 1}); + + for(int gi = 0; gi < nHiddenTensorsPerLayer * bi; gi++) { - return matrix_lin_layer_size(input_vector_sz, hidden_vec_sz, gates) + - static_cast(hidden_vec_sz + hidden_xinput_size(hidden_vec_sz, 0)) * - hidden_vec_sz * static_cast(layers_cnt - 1) * gates; + CopyTensor(handle, x_desc, x, ht_desc, reserveSpace, 0, gi * hidden_size); } - - MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); + return; } - public: - WeightsBufferHelper( - int input_vector_sz, int hidden_vec_sz, int layers_cnt, int bias_mode, int gates) - : in_vec(input_vector_sz), - h_vec(hidden_vec_sz), - x_in_vec(hidden_xinput_size(hidden_vec_sz, 0)), - layers(layers_cnt), - gates_cnt(gates), - bias_cnt(bias_mode), - matrix_normal_start_off(matrix_lin_layer_size(input_vector_sz, hidden_vec_sz, gates)), - bias_start_off( - bias_start_offset(input_vector_sz, hidden_vec_sz, layers_cnt, gates, 0)) - { - } + const int m = RBuff.batches_per_layer, n = RBuff.gemm_write_size(), + k = layer > 0 ? RBuff.gemm_write_size() : in_vec_size; - const int in_vec, h_vec; - const int x_in_vec; // for bidirect TODO + const int lda = layer > 0 ? RBuff.gemm_write_stride() : in_vec_size, ldb = k, + ldc = RBuff.gemm_write_stride(); - const int layers; - const int gates_cnt; - const int - bias_cnt; // 0 - no bisa; 1 - one bias; 2 - separate bias for x_vec and for hidden_vec - private: - const size_t matrix_normal_start_off; - const size_t bias_start_off; + const miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + m, + n, + k, + lda, + ldb, + ldc, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc.GetType(), + false}; - public: - auto get_matrix_x_size(int layer_id) const - { - return (layer_id > 0 ? x_in_vec : in_vec) * h_vec; - } - auto get_matrix_h_size() const { return h_vec * h_vec; } - auto get_matrix_layer_size(int layer_id) const - { - return get_matrix_x_size(layer_id) * gates_cnt + get_matrix_h_size() * gates_cnt; - } + const auto wx_offset = WeiBuf.input_weight_offset(layer); + const auto ht_offset = RBuff.layer_offset(layer); - size_t get_matrix_x_off(int layer_id) const - { - if(layer_id > 0) - { - return matrix_normal_start_off + - static_cast(layer_id - 1) * get_matrix_layer_size(layer_id); - } - else - { - return 0; - } - }; + const auto xt_offset = + layer > 0 ? RBuff.hidden_offset(layer - 1, 0, RnnDirection::Forward) : 0; - size_t get_matrix_h_off(int layer_id) const + const auto input_ptr = layer > 0 ? 
reserveSpace : x; + // Ht(t)^ = Whx(t)*x(t), t = 0:seq_len - 1 + // + const miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + input_ptr, + xt_offset, + w, + wx_offset, + reserveSpace, + ht_offset, + GemmBackend_t::rocblas); + if(gemm_status != miopenStatusSuccess) { - if(layer_id > 0) + if(gemm_status == miopenStatusNotImplemented) { - return get_matrix_x_off(layer_id) + - static_cast(h_vec * x_in_vec * gates_cnt); + MIOPEN_LOG_E("GEMM not implemented"); } else { - return get_matrix_x_off(layer_id) + static_cast(h_vec * in_vec) * gates_cnt; + MIOPEN_LOG_E("GEMM failed"); } - }; + } + }; - int bias_vector_size() const { return h_vec; } - int bias_vector_mul_gate() const { return bias_vector_size() * gates_cnt; } - int bias_stride() const { return bias_vector_mul_gate(); } + auto call_relu_tan_bias_add = [this, + &RBuff, + &WeiBuf, + &handle, + rnn_data_type, + reserveSpace, + w, + &hx, + &hidden_size, + &seq_len, + max_batch, + &batches, + &bacc_per_time](int layer) { + float alpha0 = 1; + float alpha1 = 1; + float beta = 0; + + auto bias_desc = miopen::TensorDescriptor(rnn_data_type, + {1, 1, WeiBuf.bias_stride()}, + {WeiBuf.bias_stride(), WeiBuf.bias_stride(), 1}); + + auto ht_desc = + miopen::TensorDescriptor(rnn_data_type, + {1, RBuff.batches_per_layer, WeiBuf.bias_stride()}, + {RBuff.layer_stride(), RBuff.gemm_write_stride(), 1}); + + // Ht(t)^ = Ht(t)^ + b, t = 0:seq_len - 1 + // + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + ht_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta, + ht_desc, + reserveSpace, + RBuff.layer_offset(layer), + WeiBuf.bias_off(layer, 0, RnnDirection::Forward), + RBuff.layer_offset(layer)); - size_t bias_relative_off(int layer_id, int bias_id) const + if(hx != nullptr) { - return static_cast(layer_id * bias_cnt + bias_id) * gates_cnt * h_vec; + // Ht(t)^ = H(t)^ + hx_bias, t = 0:seq_len - 1 + // + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + ht_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta, + ht_desc, + reserveSpace, + RBuff.layer_offset(layer), + WeiBuf.bias_off(layer, 1, RnnDirection::Forward), + RBuff.layer_offset(layer)); + return; } - size_t get_bias_off(int layer_id, int bias_id) const - { - return bias_start_off + bias_relative_off(layer_id, bias_id); - } + // hx == nullptr + // + if((RBuff.batches_per_layer - max_batch) <= 0) + return; - } WeiBuf(in_vec, hidden_size, nLayers, biasMode * 2, gates_cnt); + ht_desc = miopen::TensorDescriptor(rnn_data_type, + {1, RBuff.batches_per_layer - max_batch, hidden_size}, + {RBuff.layer_stride(), RBuff.gemm_write_stride(), 1}); - struct ReserveBufferHelper - { - struct RBuffHelper - { - int element, save_point, batch; - size_t layer; - }; + bias_desc = miopen::TensorDescriptor( + rnn_data_type, {1, 1, hidden_size}, {WeiBuf.bias_stride(), WeiBuf.bias_stride(), 1}); - private: - auto Reserve_Buffer_strides(int save_point_sz, - int batches_per_layer, - int save_points, - int bidirect_mode = 0) const - { - const auto element_st = 1; - const auto save_point_st = element_st * save_point_sz; - const auto batch_st = save_point_st * save_points; - const auto layer_st = static_cast(batch_st) * batches_per_layer; - if(bidirect_mode == 0) - return RBuffHelper{element_st, save_point_st, batch_st, layer_st}; - MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); - } + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + ht_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta, + ht_desc, + reserveSpace, + RBuff.gemm_write_offset(layer, max_batch, 
RnnDirection::Forward), + WeiBuf.bias_off(layer, 1, RnnDirection::Forward), + RBuff.gemm_write_offset(layer, max_batch, RnnDirection::Forward), + true); - public: - enum save_point - { - F = 1, - I = 0, - G = 2, - O = 3, - St = 4, - Ht = 5 - }; + if(dirMode == 0u) + return; - ReserveBufferHelper(int hidden_vec_sz, - int save_point_sz, - int layers_cnt, - int batches_per_layer, - int save_points, - int gates_cnt) - : h_vec(hidden_vec_sz), - save_point_size(save_point_sz), - layers(layers_cnt), - batches(batches_per_layer), - save_points_cnt(save_points), - gates(gates_cnt), - strides(Reserve_Buffer_strides(save_point_sz, batches, save_points, 0)) + if(max_batch == batches.at(seq_len - 1, RnnDirection::Forward)) { + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + ht_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta, + ht_desc, + reserveSpace, + RBuff.gemm_write_offset(layer, 0, RnnDirection::Backward), + WeiBuf.bias_off(layer, 1, RnnDirection::Backward), + RBuff.gemm_write_offset(layer, 0, RnnDirection::Backward), + true); + return; } - const int h_vec; - const int save_point_size; // for bidirect TODO - - const int layers; - const int batches; - const int save_points_cnt; - const int gates; - const RBuffHelper strides; - - size_t layer_offset(int layer) const { return static_cast(layer) * strides.layer; } - auto layer_stride() const { return strides.layer; } - - auto gemm_write_size() const { return h_vec * gates; } - auto gemm_write_stride() const - { - return strides.batch; - } // save_point_size * save_points_cnt - - size_t gemm_write_relative_offset(int batch_id) const - { - return static_cast(gemm_write_stride()) * batch_id; - } - - size_t gemm_write_offset(int layer, int batch_id) const - { - return layer_offset(layer) + static_cast(gemm_write_stride()) * batch_id; - } - - auto ht_relative_offset() const { return save_point::Ht * save_point_size; } - - auto ct_relative_offset() const { return save_point::St * save_point_size; } - - auto get_gate_relative_offset(int gate_id) const { return gate_id * save_point_size; } - - size_t ht_offset(int layer_id, int batch_id) const + // Ht(t)^ = Ht(t)^ + bias2, t = 1:seq_len - 1, reverse direction + // + for(int ti = 0; ti < seq_len - 1; ti++) { - return layer_offset(layer_id) + gemm_write_relative_offset(batch_id) + - ht_relative_offset(); - } + auto ht_offset = RBuff.gemm_write_offset( + layer, bacc_per_time.at(ti, RnnDirection::Forward), RnnDirection::Backward); - size_t extra_save_point_offset(int layer_id, int batch_id) const - { - return (static_cast(batches) * layers * gemm_write_stride()) // all data offset - + (static_cast(batches) * layer_id) * h_vec + - static_cast(batch_id * h_vec); + ht_desc = miopen::TensorDescriptor( + rnn_data_type, + {1, batches.at(ti + 1, RnnDirection::Forward), hidden_size}, + {WeiBuf.bias_stride(), WeiBuf.bias_stride(), 1}); + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + ht_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta, + ht_desc, + reserveSpace, + ht_offset, + WeiBuf.bias_off(layer, 1, RnnDirection::Backward), + ht_offset, + true); } - - } RBuff(hidden_size, hidden_size, nLayers, total_batch_size, save_points_cnt, gates_cnt); - - auto call_x_gemm = [&RBuff, - &WeiBuf, - &InBuff_strides, - &bacc_per_time, - &handle, - &xDesc, - reserveSpace, - x, - w, - hidden_size, - in_vec](int layer, int start_time, int time_cnt, float beta_t = 1) { - const auto start_b = bacc_per_time[start_time]; - const auto batch_sz = bacc_per_time[start_time + time_cnt] - start_b; - - const int m = 
batch_sz, n = RBuff.gemm_write_size(), k = layer > 0 ? hidden_size : in_vec; - const int lda = layer > 0 ? RBuff.gemm_write_stride() : InBuff_strides.batch, ldb = k, - ldc = RBuff.gemm_write_stride(); - - const miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - m, - n, - k, - lda, - ldb, - ldc, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - beta_t, // beta - xDesc.GetType(), - false}; - - const auto wx_off = WeiBuf.get_matrix_x_off(layer); - const auto out_offset = RBuff.gemm_write_offset(layer, start_b); - - const auto x_in_offset = layer > 0 ? RBuff.ht_offset(layer - 1, start_b) - : static_cast(start_b * InBuff_strides.batch); - const auto in_ptr = layer > 0 ? reserveSpace : x; - - const miopenStatus_t gemm_status = CallGemm(handle, - gemm_desc, - in_ptr, - x_in_offset, - w, - wx_off, - reserveSpace, - out_offset, - GemmBackend_t::rocblas); - if(gemm_status != miopenStatusSuccess) - MIOPEN_THROW("GEMM execution failure"); }; - auto call_bias_add = [&RBuff, &WeiBuf, &handle, &wDesc, reserveSpace, w](int layer, - float beta_t = 0) { - float alpha0 = 1; - float alpha1 = 1; - const auto bias_stride = WeiBuf.bias_stride(); + auto call_relu_tan_hidden_gemm = [&RBuff, + &WeiBuf, + hidden_size, + &get_HxBuff_offset, + &handle, + &xDesc, + reserveSpace, + hx, + w, + &batches, + &bacc_per_time](int layer, int time, RnnDirection direction) { + if(time == 0 && hx == nullptr) + return; - const auto bias_desc = - miopen::TensorDescriptor(wDesc.GetType(), - std::vector{1, 1, WeiBuf.bias_vector_mul_gate()}, - std::vector{bias_stride, bias_stride, 1}); + const auto m = direction == RnnDirection::Forward ? batches.at(time, direction) + : time == 0 ? batches.at(time, direction) + : batches.prev(time, direction); - const auto hidden_interim_desc = miopen::TensorDescriptor( - wDesc.GetType(), - std::vector{1, RBuff.batches, WeiBuf.bias_vector_mul_gate()}, - std::vector{ - RBuff.batches * RBuff.gemm_write_stride(), RBuff.gemm_write_stride(), 1}); + const int n = hidden_size, k = hidden_size; - const auto RB_layer_out_off = RBuff.layer_offset(layer); - const auto w_bias_layer_start_off = WeiBuf.get_bias_off(layer, 0); + const int lda = (time != 0) ? RBuff.gemm_write_stride() : hidden_size; - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - hidden_interim_desc, - reserveSpace, // A - &alpha1, - bias_desc, - w, // B - &beta_t, - hidden_interim_desc, - reserveSpace, // C - RB_layer_out_off, // A offset - w_bias_layer_start_off, // B offset - RB_layer_out_off, // C offset - true); + const int ldb = hidden_size, ldc = RBuff.gemm_write_stride(); - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - hidden_interim_desc, - reserveSpace, - &alpha1, - bias_desc, - w, - &beta_t, - hidden_interim_desc, - reserveSpace, - RB_layer_out_off, - w_bias_layer_start_off + bias_stride, - RB_layer_out_off, - true); - }; + const auto ht_ptr = time > 0 ? 
reserveSpace : hx; - auto call_hx_gemm = [&RBuff, - &WeiBuf, - &get_HxBuff_offset, - &bacc_per_time, - &in_n, - &handle, - &xDesc, - reserveSpace, + if(time != 0 && direction == RnnDirection::Backward && hx != nullptr && + batches.at(time, direction) > batches.prev(time, direction)) + { + auto dbatches = batches.at(time, direction) - batches.prev(time, direction); + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + dbatches, + n, + k, + hidden_size, + ldb, + ldc, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc.GetType(), + false}; + + // Ht(t)^ = Ht(t)^ + Whx * hx, + // for batches = bacc_per_time.prev(time, direction) - dbatches : + // bacc_per_time.prev(time, direction), t = 1:seq_len - 1 + // + const miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, hx, + get_HxBuff_offset(layer, batches.prev(time, direction), direction), w, - hidden_size](int layer, int cur_time) { - const int m = in_n.at(cur_time), n = RBuff.gemm_write_size(), k = hidden_size; - - const int lda = (cur_time != 0) ? RBuff.gemm_write_stride() : hidden_size, - ldb = hidden_size, ldc = RBuff.gemm_write_stride(); - - const auto hx_ptr_offset = (cur_time == 0) - ? get_HxBuff_offset(layer) - : RBuff.ht_offset(layer, bacc_per_time[cur_time - 1]); + WeiBuf.hidden_weight_offset(layer, direction), + reserveSpace, + RBuff.gemm_write_offset( + layer, bacc_per_time.prev(time, direction) - dbatches, direction), + GemmBackend_t::rocblas); - if(cur_time == 0) - { - if(hx == nullptr) - return; + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } } const miopen::GemmDescriptor gemm_desc_hx = GemmDescriptor{false, @@ -464,1801 +456,1039 @@ void RNNDescriptor::RNNForwardTraining_MS(Handle& handle, xDesc.GetType(), false}; - const auto RB_layer_save_points_off = - RBuff.gemm_write_offset(layer, bacc_per_time[cur_time]); + const auto ht_offset = + (time == 0) + ? get_HxBuff_offset(layer, 0, direction) + : RBuff.hidden_offset(layer, bacc_per_time.prev(time, direction), direction); - const auto hx_ptr = cur_time > 0 ? 
reserveSpace : hx; + const auto not_activated_ht_offset = + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction); + // Ht(t)^ = Ht(t)^ + Whh * Ht(t-1) + // const miopenStatus_t gemm_status = CallGemm(handle, gemm_desc_hx, - hx_ptr, - hx_ptr_offset, + ht_ptr, + ht_offset, w, - WeiBuf.get_matrix_h_off(layer), + WeiBuf.hidden_weight_offset(layer, direction), reserveSpace, - RB_layer_save_points_off, + not_activated_ht_offset, GemmBackend_t::rocblas); - if(gemm_status != miopenStatusSuccess) - MIOPEN_THROW("GEMM execution failure"); + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } }; - auto call_hidden_state_update = [&RBuff, - &get_HxBuff_offset, - &bacc_per_time, - &in_n, - &handle, - &wDesc, - reserveSpace, - cx, - max_batch, - hidden_size](int layer_id, int time_id) { - auto RB_layer_save_points_off = - RBuff.layer_offset(layer_id) + RBuff.gemm_write_relative_offset(bacc_per_time[time_id]); - - auto is_seq_begin = time_id == 0; + auto call_relu_tan_hidden_state_update = + [&RBuff, + hidden_size, + &handle, + rnn_data_type, + reserveSpace, + &activDesc, + &batches, + &bacc_per_time](int layer_id, int time, RnnDirection direction) { + float alpha = 1, beta = 0; + + auto ht_desc = + miopen::TensorDescriptor(rnn_data_type, + {1, batches.at(time, direction), hidden_size}, + {RBuff.layer_stride(), RBuff.gemm_write_stride(), 1}); + + const auto ht_not_activated_offset = + RBuff.gemm_write_offset(layer_id, bacc_per_time.at(time, direction), direction); + + const auto hidden_offset = + RBuff.hidden_offset(layer_id, bacc_per_time.at(time, direction), direction); + + // Ht(t) = @(Ht^(t)) + // + activDesc.Forward(handle, + &alpha, + // input tensor descriptor + ht_desc, + // input pointer + reserveSpace, + &beta, + // output tensor descriptor + ht_desc, + // output pointer + reserveSpace, + // input tensor offset + ht_not_activated_offset, + // output tensor offset + hidden_offset); + }; - const int direction = 0; - const int cur_batch = in_n.at(time_id), use_batch = in_n.at(time_id); + auto call_relu_tan_update_output = [&RBuff, + &get_HxBuff_offset, + hidden_size, + &handle, + rnn_data_type, + reserveSpace, + hy, + max_batch, + &bacc_per_time, + seq_len, + &batches](int layer_id, int time, RnnDirection direction) { + if(hy == nullptr) + return; + + auto dbatches = time == seq_len - 1 + ? batches.at(time, direction) + : batches.at(time, direction) - batches.next(time, direction); + + if(dbatches <= 0) + return; + + const std::vector hcy_src_stride{ + RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 1}; - const int hy_stride = RBuff.gemm_write_stride(), wei_len = RBuff.gemm_write_size(), - wei_stride = RBuff.gemm_write_size(); + const std::vector hcy_dst_stride{ + static_cast(hidden_size * max_batch), static_cast(hidden_size), 1}; - const size_t cx_offset = get_HxBuff_offset(layer_id); + auto batch_id_relative = batches.at(time, direction) - dbatches; - const size_t i_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(0), - f_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(1), - o_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(2), - c_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(3); + auto batch_id_abs = time == seq_len - 1 + ? 
bacc_per_time.at(time, direction) + : bacc_per_time.at(time, direction) + batches.next(time, direction); - const size_t cell_offset = RB_layer_save_points_off + RBuff.ct_relative_offset(), - hidden_offset = RB_layer_save_points_off + RBuff.ht_relative_offset(); + const std::vector hcy_copy_size{ + 1, static_cast(dbatches), static_cast(hidden_size)}; - const size_t cell_offset_pre = - (time_id == 0) ? 0 - : RBuff.layer_offset(layer_id) + - RBuff.gemm_write_relative_offset(bacc_per_time[time_id - 1]) + - RBuff.ct_relative_offset(); + auto src_desc = miopen::TensorDescriptor(rnn_data_type, hcy_copy_size, hcy_src_stride); + auto dst_desc = miopen::TensorDescriptor(rnn_data_type, hcy_copy_size, hcy_dst_stride); - const size_t activ_cell_offset = - RBuff.extra_save_point_offset(layer_id, bacc_per_time[time_id]); + CopyTensor(handle, + src_desc, + reserveSpace, + dst_desc, + hy, + RBuff.hidden_offset(layer_id, batch_id_abs, direction), + get_HxBuff_offset(layer_id, batch_id_relative, direction)); + }; - LSTMForwardHiddenStateUpdate(handle, - wDesc.GetType(), - false, - is_seq_begin, - direction, - max_batch, - cur_batch, - use_batch, - - hidden_size, - hy_stride, - wei_len, - wei_stride, - cx, - cx_offset, - reserveSpace, - i_offset, - f_offset, - o_offset, - c_offset, - cell_offset, - cell_offset_pre, - activ_cell_offset, - hidden_offset); - }; + for(int layer_id = 0; layer_id < nLayers; layer_id++) + { + call_relu_tan_input_gemm(layer_id); + if(biasMode != 0u) + call_relu_tan_bias_add(layer_id); - auto call_hy_cy_update = [&RBuff, - &get_HxBuff_offset, - &bacc_per_time, - &in_n, - &handle, - &wDesc, - reserveSpace, - hy, - cy, - max_batch, - hidden_size, - seq_len](int layer_id) { - if(hy != nullptr || (cy != nullptr)) + for(int time = 0; time < seq_len; time++) { - auto hcy_layer_offset = get_HxBuff_offset(layer_id); - - const std::vector hcy_src_stride{ - RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 1}; - const std::vector hcy_dst_stride{ - static_cast(hidden_size * max_batch), static_cast(hidden_size), 1}; + call_relu_tan_hidden_gemm(layer_id, time, RnnDirection::Forward); - for(int time_i = seq_len - 1; time_i >= 0; time_i--) - { - auto copy_batch = (time_i == seq_len - 1) ? 
in_n.at(time_i) - : in_n.at(time_i) - in_n.at(time_i + 1); - if(copy_batch > 0) - { - auto batch_id_relative = in_n.at(time_i) - copy_batch; - auto batch_id_abs = bacc_per_time[time_i] + batch_id_relative; + call_relu_tan_hidden_state_update(layer_id, time, RnnDirection::Forward); - auto hcy_batch_offset = batch_id_relative * hidden_size; + if(dirMode == 0u) + continue; - auto src_batch_offset = RBuff.layer_offset(layer_id) + - RBuff.gemm_write_relative_offset(batch_id_abs); + call_relu_tan_hidden_gemm(layer_id, time, RnnDirection::Backward); - const std::vector hcy_copy_size{ - 1, static_cast(copy_batch), static_cast(hidden_size)}; + call_relu_tan_hidden_state_update(layer_id, time, RnnDirection::Backward); + } - auto src_desc = - miopen::TensorDescriptor(wDesc.GetType(), hcy_copy_size, hcy_src_stride); - auto dst_desc = - miopen::TensorDescriptor(wDesc.GetType(), hcy_copy_size, hcy_dst_stride); + for(int time = seq_len - 1; time >= 0; time--) + { + call_relu_tan_update_output(layer_id, time, RnnDirection::Forward); - if(hy != nullptr) - { - CopyTensor(handle, - src_desc, - reserveSpace, - dst_desc, - hy, - src_batch_offset + RBuff.ht_relative_offset(), - hcy_layer_offset + hcy_batch_offset); - } + if(dirMode == 0u) + continue; - if(cy != nullptr) - { - CopyTensor(handle, - src_desc, - reserveSpace, - dst_desc, - cy, - src_batch_offset + RBuff.ct_relative_offset(), - hcy_layer_offset + hcy_batch_offset); - } - } - } + call_relu_tan_update_output(layer_id, time, RnnDirection::Backward); } - }; + } - auto call_sync_all_stream_pull_to_root_stream = [&stream_pull, root_stream_id]() { - const miopen::HipEventPtr main_event = make_hip_fast_event(); - hipEventRecord(main_event.get(), stream_pull[root_stream_id]); + // output tensor copy + { + const std::vector y_copy_size{ + 1, static_cast(total_batch_size), static_cast(out_vec_size)}; - for(int i = 0; i < stream_pull.size(); i++) - { - if(i != root_stream_id) - hipStreamWaitEvent(stream_pull[i], main_event.get(), 0); - } - }; + const std::vector y_src_stride{ + RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 1}; - auto sync_root_to_all_stream_pull = [&stream_pull, root_stream_id]() { - hipStream_t root_stream = stream_pull[root_stream_id]; - for(int i = 0; i < stream_pull.size(); i++) - { - if(i != root_stream_id) - { - const miopen::HipEventPtr sync_event = make_hip_fast_event(); - hipEventRecord(sync_event.get(), stream_pull[i]); - hipStreamWaitEvent(root_stream, sync_event.get(), 0); - } - } - }; + const std::vector y_dst_stride{static_cast(out_vec_size * total_batch_size), + static_cast(out_vec_size), + 1}; - if(seq_len == 0) - return; + auto src_desc = miopen::TensorDescriptor(rnn_data_type, y_copy_size, y_src_stride); + auto y_dst_desc = miopen::TensorDescriptor(rnn_data_type, y_copy_size, y_dst_stride); - const int try_chunks_cnt = 16; - const int time_chunk_sz = ((seq_len + try_chunks_cnt - 1) / try_chunks_cnt); - const int chunks_cnt = (seq_len + time_chunk_sz - 1) / time_chunk_sz; + CopyTensor(handle, + src_desc, + reserveSpace, + y_dst_desc, + y, + RBuff.hidden_offset(nLayers - 1, 0, RnnDirection::Forward), + 0); + } +#else + (void)handle; + (void)seq_array; + (void)xDesc; + (void)x; + (void)hxDesc; + (void)hx; + (void)cx; + (void)wDesc; + (void)w; + (void)yDesc; + (void)y; + (void)hy; + (void)cy; + (void)reserveSpace; + (void)reserveSpaceSize; - std::vector layer_inx_cur_time(nLayers, 0); - std::vector layer_hx_cur_time(nLayers, 0); - std::vector layer_upd_cur_time(nLayers, 0); + MIOPEN_THROW("GEMM is not supported"); 
+#endif +} - std::vector> layer_chunk_end_event; +void RNNDescriptor::RNNForwardTraining_MS(Handle& handle, + std::vector& seq_array, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + ConstData_t cx, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& yDesc, + Data_t y, + Data_t hy, + Data_t cy, + Data_t reserveSpace, + size_t reserveSpaceSize) const +{ +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + std::vector in_n; + int in_vec = xDesc.GetLengths()[1]; // input vector size + int out_vec = yDesc.GetLengths()[1]; // output vector size - layer_chunk_end_event.resize(nLayers); - for(int layer_id = 0; layer_id < nLayers; layer_id++) + int seq_len = seq_array.size(); + int max_batch = seq_array[0]; + int hidden_size; + + std::tie(std::ignore, max_batch, hidden_size) = miopen::tien<3>(hxDesc.GetLengths()); + + auto extra_stream_cnt = 2; + handle.ReserveExtraStreamsInPool(extra_stream_cnt); + + auto root_stream_id = 0; + std::vector stream_pull; + for(int i = 0; i <= extra_stream_cnt; i++) { - layer_chunk_end_event[layer_id].resize(chunks_cnt); - for(int chunk_id = 0; chunk_id < chunks_cnt; chunk_id++) - layer_chunk_end_event[layer_id][chunk_id] = make_hip_fast_event(); + handle.SetStreamFromPool(i); + stream_pull.push_back(handle.GetStream()); } - std::vector layer_stream_id(nLayers, 2); - layer_stream_id[0] = 1; + handle.SetStreamFromPool(root_stream_id); - auto call_inx_next_chunk_preload = [&](int layer_id) { - auto start_time = layer_inx_cur_time[layer_id]; - auto time_cnt = std::min(time_chunk_sz, seq_len - start_time); + int total_batch_size = 0; + std::vector bacc_per_time(seq_len + 1); - call_x_gemm(layer_id, start_time, time_cnt); - layer_inx_cur_time[layer_id] += time_chunk_sz; - }; + for(int i = 0; i < seq_len; i++) + { + bacc_per_time[i] = total_batch_size; + total_batch_size += seq_array[i]; + in_n.push_back(seq_array[i]); + } + bacc_per_time[seq_len] = total_batch_size; - auto call_hx_next_gemm = [&](int layer_id) { - auto cur_time = layer_hx_cur_time[layer_id]; - if(cur_time < seq_len) - { - call_hx_gemm(layer_id, cur_time); - layer_hx_cur_time[layer_id]++; - } - }; + const struct + { + int batch; + } InBuff_strides{in_vec}; - auto call_next_hidden_state_update = [&](int layer_id) { - auto cur_time = layer_upd_cur_time[layer_id]; - if(cur_time < seq_len) - { - call_hidden_state_update(layer_id, cur_time); - layer_upd_cur_time[layer_id]++; - } + auto get_HxBuff_offset = [&](int layer_id) { + return layer_id * (static_cast(hidden_size) * max_batch); }; - auto call_next_chunk_compute = [&handle, - &stream_pull, - &layer_stream_id, - &call_next_hidden_state_update, - &call_hx_next_gemm, - &call_inx_next_chunk_preload, - &layer_upd_cur_time, - &layer_chunk_end_event, - time_chunk_sz, - seq_len](int layer_id) { - auto stream_id = layer_stream_id[layer_id]; - handle.SetStreamFromPool(stream_id); - - const int chunk_id = layer_upd_cur_time[layer_id] / time_chunk_sz; - const int chunk_time = std::min(time_chunk_sz, seq_len - chunk_id * time_chunk_sz); + int gates_cnt = 4; + int save_points_cnt = 6; - if(layer_id > 0 && layer_stream_id[layer_id - 1] != stream_id) + struct WeightsBufferHelper + { + private: + auto hidden_xinput_size(int hidden_sz, int bidirect_mode) const { - hipStreamWaitEvent( - stream_pull[stream_id], layer_chunk_end_event[layer_id - 1][chunk_id].get(), 0); + if(bidirect_mode == 0) + return hidden_sz; + MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); } - if(!(layer_id 
== 0 && chunk_id == 1)) + auto matrix_lin_layer_size(int input_vector_sz, int hidden_vec_sz, int gates) const { - call_inx_next_chunk_preload(layer_id); + return (input_vector_sz + hidden_vec_sz) * hidden_vec_sz * gates; } - - for(int time_id = 0; time_id < chunk_time; time_id++) + size_t bias_start_offset(int input_vector_sz, + int hidden_vec_sz, + int layers_cnt, + int gates, + int bidirect_mode) const { - call_hx_next_gemm(layer_id); - call_next_hidden_state_update(layer_id); + if(bidirect_mode == 0) + { + return matrix_lin_layer_size(input_vector_sz, hidden_vec_sz, gates) + + static_cast(hidden_vec_sz + hidden_xinput_size(hidden_vec_sz, 0)) * + hidden_vec_sz * static_cast(layers_cnt - 1) * gates; + } + + MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); } - hipEventRecord(layer_chunk_end_event[layer_id][chunk_id].get(), stream_pull[stream_id]); - }; - { // reserveSpace clean set 0 - const int fill_val = 0; - // if(biasMode == 0u) req - hipMemsetAsync(reserveSpace, fill_val, reserveSpaceSize, handle.GetStream()); - } + public: + WeightsBufferHelper( + int input_vector_sz, int hidden_vec_sz, int layers_cnt, int bias_mode, int gates) + : in_vec(input_vector_sz), + h_vec(hidden_vec_sz), + x_in_vec(hidden_xinput_size(hidden_vec_sz, 0)), + layers(layers_cnt), + gates_cnt(gates), + bias_cnt(bias_mode), + matrix_normal_start_off(matrix_lin_layer_size(input_vector_sz, hidden_vec_sz, gates)), + bias_start_off( + bias_start_offset(input_vector_sz, hidden_vec_sz, layers_cnt, gates, 0)) + { + } - // stage 0 bias and input preload - // stage 0.2 first chunk compute and preload - { - call_sync_all_stream_pull_to_root_stream(); - const auto first_layer_id = 0; - const auto stream_id = layer_stream_id[first_layer_id]; // 1 - const auto extra_stream_id = 2; + const int in_vec, h_vec; + const int x_in_vec; // for bidirect TODO - handle.SetStreamFromPool(stream_id); + const int layers; + const int gates_cnt; + const int + bias_cnt; // 0 - no bias; 1 - one bias; 2 - separate bias for x_vec and for hidden_vec + private: + const size_t matrix_normal_start_off; + const size_t bias_start_off; - if(biasMode != 0u) - call_bias_add(first_layer_id); + public: + auto get_matrix_x_size(int layer_id) const + { + return (layer_id > 0 ? 
x_in_vec : in_vec) * h_vec; + } + auto get_matrix_h_size() const { return h_vec * h_vec; } + auto get_matrix_layer_size(int layer_id) const + { + return get_matrix_x_size(layer_id) * gates_cnt + get_matrix_h_size() * gates_cnt; + } - call_next_chunk_compute(first_layer_id); + size_t get_matrix_x_off(int layer_id) const + { + if(layer_id > 0) + { + return matrix_normal_start_off + + static_cast(layer_id - 1) * get_matrix_layer_size(layer_id); + } + else + { + return 0; + } + }; - handle.SetStreamFromPool(extra_stream_id); + size_t get_matrix_h_off(int layer_id) const + { + if(layer_id > 0) + { + return get_matrix_x_off(layer_id) + + static_cast(h_vec * x_in_vec * gates_cnt); + } + else + { + return get_matrix_x_off(layer_id) + static_cast(h_vec * in_vec) * gates_cnt; + } + }; - if(biasMode != 0u) + int bias_vector_size() const { return h_vec; } + int bias_vector_mul_gate() const { return bias_vector_size() * gates_cnt; } + int bias_stride() const { return bias_vector_mul_gate(); } + + size_t bias_relative_off(int layer_id, int bias_id) const { - for(int layer_id = 1; layer_id < nLayers; layer_id++) - call_bias_add(layer_id); + return static_cast(layer_id * bias_cnt + bias_id) * gates_cnt * h_vec; } - call_inx_next_chunk_preload(first_layer_id); + size_t get_bias_off(int layer_id, int bias_id) const + { + return bias_start_off + bias_relative_off(layer_id, bias_id); + } - // sync first to second stream - const miopen::HipEventPtr next_chunk_inx = make_hip_fast_event(); - hipEventRecord(next_chunk_inx.get(), stream_pull[extra_stream_id]); - hipStreamWaitEvent(stream_pull[stream_id], next_chunk_inx.get(), 0); - } + } WeiBuf(in_vec, hidden_size, nLayers, biasMode * 2, gates_cnt); - for(int layer_id = 0; layer_id < nLayers; layer_id++) + struct ReserveBufferHelper { + struct RBuffHelper + { + int element, save_point, batch; + size_t layer; + }; - const auto main_stream_id = 1; - handle.SetStreamFromPool(main_stream_id); + private: + auto Reserve_Buffer_strides(int save_point_sz, + int batches_per_layer, + int save_points, + int bidirect_mode = 0) const + { + const auto element_st = 1; + const auto save_point_st = element_st * save_point_sz; + const auto batch_st = save_point_st * save_points; + const auto layer_st = static_cast(batch_st) * batches_per_layer; + if(bidirect_mode == 0) + return RBuffHelper{element_st, save_point_st, batch_st, layer_st}; + MIOPEN_THROW("execution failure: bidirect is not supported by this solver"); + } - // check for wich stream was assigned this layer. 
If it differs from current - set stream - // wait event - if(layer_stream_id[layer_id] != main_stream_id) + public: + enum save_point { - auto chunk_id = layer_upd_cur_time[layer_id] / time_chunk_sz; - if(chunk_id > 0) - { - hipStreamWaitEvent(stream_pull[main_stream_id], - layer_chunk_end_event[layer_id][chunk_id - 1].get(), - 0); - } + F = 1, + I = 0, + G = 2, + O = 3, + St = 4, + Ht = 5 + }; - layer_stream_id[layer_id] = main_stream_id; + ReserveBufferHelper(int hidden_vec_sz, + int save_point_sz, + int layers_cnt, + int batches_per_layer, + int save_points, + int gates_cnt) + : h_vec(hidden_vec_sz), + save_point_size(save_point_sz), + layers(layers_cnt), + batches(batches_per_layer), + save_points_cnt(save_points), + gates(gates_cnt), + strides(Reserve_Buffer_strides(save_point_sz, batches, save_points, 0)) + { } - const int start_chunk = layer_upd_cur_time[layer_id] / time_chunk_sz; + const int h_vec; + const int save_point_size; // for bidirect TODO - const int extra_layer_max_chunks = - start_chunk + - ((layer_id + 1 < nLayers - 1) ? (chunks_cnt - start_chunk) / 2 : chunks_cnt); + const int layers; + const int batches; + const int save_points_cnt; + const int gates; + const RBuffHelper strides; - for(int chunk_id = start_chunk; chunk_id < chunks_cnt; chunk_id++) - { + size_t layer_offset(int layer) const { return static_cast(layer) * strides.layer; } + auto layer_stride() const { return strides.layer; } - call_next_chunk_compute(layer_id); + auto gemm_write_size() const { return h_vec * gates; } + auto gemm_write_stride() const + { + return strides.batch; + } // save_point_size * save_points_cnt - int extra_compute_layer = layer_id + 1; - for(; extra_compute_layer < nLayers; extra_compute_layer++) - { - auto extra_chunk_id = layer_upd_cur_time[extra_compute_layer] / time_chunk_sz; - if(extra_chunk_id < extra_layer_max_chunks && extra_chunk_id <= chunk_id) - break; - } + size_t gemm_write_relative_offset(int batch_id) const + { + return static_cast(gemm_write_stride()) * batch_id; + } - if(extra_compute_layer < nLayers) - call_next_chunk_compute(extra_compute_layer); + size_t gemm_write_offset(int layer, int batch_id) const + { + return layer_offset(layer) + static_cast(gemm_write_stride()) * batch_id; } - handle.SetStreamFromPool(main_stream_id); - // update hy, cy - call_hy_cy_update(layer_id); - } + auto ht_relative_offset() const { return save_point::Ht * save_point_size; } - handle.SetStreamFromPool(root_stream_id); - hipStreamWaitEvent( - stream_pull[root_stream_id], layer_chunk_end_event[nLayers - 1][chunks_cnt - 1].get(), 0); + auto ct_relative_offset() const { return save_point::St * save_point_size; } - // output tensor copy - { - const std::vector y_copy_size{ - 1, static_cast(total_batch_size), static_cast(out_vec)}; + auto get_gate_relative_offset(int gate_id) const { return gate_id * save_point_size; } - const std::vector y_src_stride{ - RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 1}; + size_t ht_offset(int layer_id, int batch_id) const + { + return layer_offset(layer_id) + gemm_write_relative_offset(batch_id) + + ht_relative_offset(); + } - const std::vector y_dst_stride{ - static_cast(out_vec * total_batch_size), static_cast(out_vec), 1}; + size_t extra_save_point_offset(int layer_id, int batch_id) const + { + return (static_cast(batches) * layers * gemm_write_stride()) // all data offset + + (static_cast(batches) * layer_id) * h_vec + + static_cast(batch_id * h_vec); + } - auto src_desc = miopen::TensorDescriptor(wDesc.GetType(), y_copy_size, 
y_src_stride); - auto y_dst_desc = miopen::TensorDescriptor(wDesc.GetType(), y_copy_size, y_dst_stride); + } RBuff(hidden_size, hidden_size, nLayers, total_batch_size, save_points_cnt, gates_cnt); - CopyTensor( - handle, src_desc, reserveSpace, y_dst_desc, y, RBuff.ht_offset(nLayers - 1, 0), 0); - } + auto call_x_gemm = [&RBuff, + &WeiBuf, + &InBuff_strides, + &bacc_per_time, + &handle, + &xDesc, + reserveSpace, + x, + w, + hidden_size, + in_vec](int layer, int start_time, int time_cnt, float beta_t = 1) { + const auto start_b = bacc_per_time[start_time]; + const auto batch_sz = bacc_per_time[start_time + time_cnt] - start_b; - sync_root_to_all_stream_pull(); -#else - (void)handle; - (void)seq_array; - (void)xDesc; - (void)x; - (void)hxDesc; - (void)hx; - (void)cx; - (void)wDesc; - (void)w; - (void)yDesc; - (void)y; - (void)hy; - (void)cy; - (void)reserveSpace; - (void)reserveSpaceSize; + const int m = batch_sz, n = RBuff.gemm_write_size(), k = layer > 0 ? hidden_size : in_vec; + const int lda = layer > 0 ? RBuff.gemm_write_stride() : InBuff_strides.batch, ldb = k, + ldc = RBuff.gemm_write_stride(); - MIOPEN_THROW("GEMM is not supported"); -#endif -} + const miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + m, + n, + k, + lda, + ldb, + ldc, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + beta_t, // beta + xDesc.GetType(), + false}; -// Assuming sequence length is set to > 0 otherwise throw exception. -void RNNDescriptor::RNNForwardInference(Handle& handle, - const int seqLen, - c_array_view xDesc, - ConstData_t x, - const TensorDescriptor& hxDesc, - ConstData_t hx, - const TensorDescriptor& cxDesc, - ConstData_t cx, - const TensorDescriptor& wDesc, - ConstData_t w, - c_array_view yDesc, - Data_t y, - const TensorDescriptor& hyDesc, - Data_t hy, - const TensorDescriptor& cyDesc, - Data_t cy, - Data_t workSpace, - size_t workSpaceSize) const -{ - if(x == nullptr || w == nullptr || y == nullptr) - { - MIOPEN_THROW(miopenStatusBadParm); - } - if(hxDesc.GetSize() != cxDesc.GetSize() || hxDesc.GetSize() != hyDesc.GetSize() || - hxDesc.GetSize() != cyDesc.GetSize()) - { - MIOPEN_THROW(miopenStatusBadParm); - } - if(workSpaceSize < GetWorkspaceSize(handle, seqLen, xDesc)) - { - MIOPEN_THROW("Workspace is required"); - } + const auto wx_off = WeiBuf.get_matrix_x_off(layer); + const auto out_offset = RBuff.gemm_write_offset(layer, start_b); -#if MIOPEN_BACKEND_HIP - HipEventPtr start = nullptr; - HipEventPtr stop = nullptr; - bool is_profiling = handle.IsProfilingEnabled(); + const auto x_in_offset = layer > 0 ? RBuff.ht_offset(layer - 1, start_b) + : static_cast(start_b * InBuff_strides.batch); + const auto in_ptr = layer > 0 ? 
reserveSpace : x; - if(is_profiling) - { - handle.EnableProfiling(false); - RNNProfilingBegin(handle, start, stop); - } - try - { -#endif + const miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + in_ptr, + x_in_offset, + w, + wx_off, + reserveSpace, + out_offset, + GemmBackend_t::rocblas); + if(gemm_status != miopenStatusSuccess) + MIOPEN_THROW("GEMM execution failure"); + }; - if(paddingMode == miopenRNNIONotPadded) - { - return RNNForwardInferencePacked(handle, - seqLen, - xDesc, - x, - hxDesc, - hx, - cxDesc, - cx, - wDesc, - w, - yDesc, - y, - hyDesc, - hy, - cyDesc, - cy, - workSpace, - workSpaceSize); - } - else - { - Data_t packedXIn = workSpace; - size_t packedXInSize, packedYOutSize; - std::tie(packedXInSize, packedYOutSize) = - RNNTensorPaddingConverter::GetTempPackedBuffersSpace(*this, xDesc); + auto call_bias_add = [&RBuff, &WeiBuf, &handle, &wDesc, reserveSpace, w](int layer, + float beta_t = 0) { + float alpha0 = 1; + float alpha1 = 1; + const auto bias_stride = WeiBuf.bias_stride(); - Data_t packedYOut = - static_cast(reinterpret_cast(workSpace) + packedXInSize); + const auto bias_desc = + miopen::TensorDescriptor(wDesc.GetType(), + std::vector{1, 1, WeiBuf.bias_vector_mul_gate()}, + std::vector{bias_stride, bias_stride, 1}); - auto shifted_workSpace = static_cast(reinterpret_cast(workSpace) + - (packedXInSize + packedYOutSize)); - auto shifted_workSpace_size = workSpaceSize - (packedXInSize + packedYOutSize); - // std::vector packed_desc; - // std::vector packed_desc_ptrs; - // RNNTensorPaddingConverter::CreatePackedDescriptor() - // for future developments: as long as we don't use strides from xDesc and yDesc - // we ignoring conversion of this descriptors. - std::vector in_n(seqLen); + const auto hidden_interim_desc = miopen::TensorDescriptor( + wDesc.GetType(), + std::vector{1, RBuff.batches, WeiBuf.bias_vector_mul_gate()}, + std::vector{ + RBuff.batches * RBuff.gemm_write_stride(), RBuff.gemm_write_stride(), 1}); - for(int i = 0; i < seqLen; i++) - { - int batchval, batchvalout; - std::tie(batchval, std::ignore) = miopen::tien<2>(xDesc[i].GetLengths()); - std::tie(batchvalout, std::ignore) = miopen::tien<2>(yDesc[i].GetLengths()); - if(batchval != batchvalout) - { - MIOPEN_THROW(miopenStatusBadParm, - "Input batch length: " + std::to_string(batchval) + - ", Output batch length: " + std::to_string(batchvalout)); - } - in_n[i] = batchval; - } + const auto RB_layer_out_off = RBuff.layer_offset(layer); + const auto w_bias_layer_start_off = WeiBuf.get_bias_off(layer, 0); - RNNTensorPaddingConverter::ConvertTensorData( - handle, xDesc[0], in_n, x, packedXIn, true); + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + hidden_interim_desc, + reserveSpace, // A + &alpha1, + bias_desc, + w, // B + &beta_t, + hidden_interim_desc, + reserveSpace, // C + RB_layer_out_off, // A offset + w_bias_layer_start_off, // B offset + RB_layer_out_off, // C offset + true); - RNNDescriptor packedRnnDesc(*this); - packedRnnDesc.SetPaddingmode(miopenRNNIONotPadded); + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + hidden_interim_desc, + reserveSpace, + &alpha1, + bias_desc, + w, + &beta_t, + hidden_interim_desc, + reserveSpace, + RB_layer_out_off, + w_bias_layer_start_off + bias_stride, + RB_layer_out_off, + true); + }; - packedRnnDesc.RNNForwardInferencePacked(handle, - seqLen, - xDesc, - packedXIn, - hxDesc, - hx, - cxDesc, - cx, - wDesc, - w, - yDesc, - packedYOut, - hyDesc, - hy, - cyDesc, - cy, - shifted_workSpace, - shifted_workSpace_size); + auto call_hx_gemm = [&RBuff, + 
&WeiBuf, + &get_HxBuff_offset, + &bacc_per_time, + &in_n, + &handle, + &xDesc, + reserveSpace, + hx, + w, + hidden_size](int layer, int cur_time) { + const int m = in_n.at(cur_time), n = RBuff.gemm_write_size(), k = hidden_size; - RNNTensorPaddingConverter::ConvertTensorData( - handle, yDesc[0], in_n, packedYOut, y, false); - } + const int lda = (cur_time != 0) ? RBuff.gemm_write_stride() : hidden_size, + ldb = hidden_size, ldc = RBuff.gemm_write_stride(); -#if MIOPEN_BACKEND_HIP - } - catch(...) - { - if(is_profiling) - handle.EnableProfiling(true); - throw; - } + const auto hx_ptr_offset = (cur_time == 0) + ? get_HxBuff_offset(layer) + : RBuff.ht_offset(layer, bacc_per_time[cur_time - 1]); - if(is_profiling) - { - float eventTime_mS = RNNProfilingEnd(handle, start, stop); - handle.EnableProfiling(true); - handle.ResetKernelTime(); - handle.AccumKernelTime(eventTime_mS); - } -#endif -} -void RNNDescriptor::RNNForwardInferencePacked(Handle& handle, - const int seqLen, - c_array_view xDesc, - ConstData_t x, - const TensorDescriptor& hxDesc, - ConstData_t hx, - const TensorDescriptor& cxDesc, - ConstData_t cx, - const TensorDescriptor& wDesc, - ConstData_t w, - c_array_view yDesc, - Data_t y, - const TensorDescriptor& hyDesc, - Data_t hy, - const TensorDescriptor& cyDesc, - Data_t cy, - Data_t workSpace, - size_t workSpaceSize) const -{ - (void)cyDesc; - (void)hxDesc; - (void)cxDesc; + if(cur_time == 0) + { + if(hx == nullptr) + return; + } -#if MIOPEN_USE_GEMM + const miopen::GemmDescriptor gemm_desc_hx = GemmDescriptor{false, + false, + true, + m, + n, + k, + lda, + ldb, + ldc, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc.GetType(), + false}; - float ctime = 0.; - // reset kernel timer - profileRNNkernels(handle, 0, ctime); + const auto RB_layer_save_points_off = + RBuff.gemm_write_offset(layer, bacc_per_time[cur_time]); - std::vector in_n; - int in_h = xDesc[0].GetLengths()[1]; // input vector size - int hy_d = hyDesc.GetLengths()[0]; // biNumLayers - int hy_n = hyDesc.GetLengths()[1]; // max batch size - int hy_h = hyDesc.GetLengths()[2]; // hidden size - int out_h = yDesc[0].GetLengths()[1]; // output vector size - int bi = dirMode != 0u ? 2 : 1; + const auto hx_ptr = cur_time > 0 ? 
reserveSpace : hx; - if(in_h <= 0 || hy_h <= 0 || hy_n <= 0 || hy_d <= 0 || out_h <= 0 || seqLen <= 0) - { - MIOPEN_THROW(miopenStatusBadParm); - } + const miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc_hx, + hx_ptr, + hx_ptr_offset, + w, + WeiBuf.get_matrix_h_off(layer), + reserveSpace, + RB_layer_save_points_off, + GemmBackend_t::rocblas); - if(out_h != (bi * hy_h)) - { - MIOPEN_THROW(miopenStatusBadParm, "Output size doesn't match hidden state size!"); - } + if(gemm_status != miopenStatusSuccess) + MIOPEN_THROW("GEMM execution failure"); + }; - if(inputMode == miopenRNNskip) - { - if(in_h != hy_h) - { - MIOPEN_THROW(miopenStatusBadParm, - "The input tensor size must equal to the hidden " - "state size of the network in SKIP_INPUT mode!"); - } - in_h = 0; - } + auto call_hidden_state_update = [&RBuff, + &get_HxBuff_offset, + &bacc_per_time, + &in_n, + &handle, + &wDesc, + reserveSpace, + cx, + max_batch, + hidden_size](int layer_id, int time_id) { + auto RB_layer_save_points_off = + RBuff.layer_offset(layer_id) + RBuff.gemm_write_relative_offset(bacc_per_time[time_id]); - int batch_n = 0; - for(int i = 0; i < seqLen; i++) - { - int batchval, inputvec, batchvalout, outputvec; - std::tie(batchval, inputvec) = miopen::tien<2>(xDesc[i].GetLengths()); - std::tie(batchvalout, outputvec) = miopen::tien<2>(yDesc[i].GetLengths()); - if(batchval != batchvalout) - { - MIOPEN_THROW(miopenStatusBadParm, - "Input batch length: " + std::to_string(batchval) + - ", Output batch length: " + std::to_string(batchvalout)); - } - if(i == 0) - { - if(batchval <= 0) - { - MIOPEN_THROW(miopenStatusBadParm, "Input batch is ZERO!"); - } - } - else - { - if(batchval > in_n.back() || batchval < 0) - { - MIOPEN_THROW(miopenStatusBadParm, - "Incorrect input batch size at time " + std::to_string(i) + - "! 
Batch size must not ascend!"); - } - } - in_n.push_back(batchval); - batch_n += batchval; - } - // input check end + auto is_seq_begin = time_id == 0; - int in_stride = xDesc[0].GetLengths()[1]; - int hy_stride = hy_h * bi * static_cast(workspaceScale); - int out_stride = out_h; - int wei_stride = hy_h * bi * static_cast(nHiddenTensorsPerLayer); - int uni_stride = hy_h; - int bi_stride = hy_h * bi; + const int direction = 0; + const int cur_batch = in_n.at(time_id), use_batch = in_n.at(time_id); - size_t wei_shift_bias = (in_h + hy_h + (bi * hy_h + hy_h) * (nLayers - 1)) * wei_stride; - size_t offset; - float alpha0, alpha1, beta_t; - float alpha = 1, beta = 0; + const int hy_stride = RBuff.gemm_write_stride(), wei_len = RBuff.gemm_write_size(), + wei_stride = RBuff.gemm_write_size(); - std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), - x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); - miopen::TensorDescriptor sp_desc, w_desc, x_desc, y_desc, hx_desc; + const size_t cx_offset = get_HxBuff_offset(layer_id); - sp_size[2] = workSpaceSize / GetTypeSize(wDesc.GetType()); - sp_stride[0] = sp_size[2]; - sp_stride[1] = sp_size[2]; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - SetTensor(handle, sp_desc, workSpace, &beta); - // Update time - profileRNNkernels(handle, 1, ctime); - sp_stride[0] = batch_n * hy_stride; - sp_stride[1] = hy_stride; - sp_size[2] = 1; - w_stride[0] = wei_stride; - w_stride[1] = wei_stride; - x_stride[0] = batch_n * in_stride; - x_stride[1] = in_stride; - y_stride[0] = batch_n * out_stride; - y_stride[1] = out_stride; - if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) - { - hx_size[2] = hy_d * hy_n * hy_h; - hx_stride[0] = hx_size[2]; - hx_stride[1] = hx_size[2]; - hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); - if(hy != nullptr) - { - SetTensor(handle, hx_desc, hy, &beta); - // Update time - profileRNNkernels(handle, 1, ctime); - } - if(rnnMode == miopenLSTM && cy != nullptr) - { - SetTensor(handle, hx_desc, cy, &beta); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - hx_stride[0] = in_n.at(0) * uni_stride; - hx_stride[1] = uni_stride; + const size_t i_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(0), + f_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(1), + o_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(2), + c_offset = RB_layer_save_points_off + RBuff.get_gate_relative_offset(3); - int wei_shift, prelayer_shift; - int wei_len = 0; - int hid_off = 0; + const size_t cell_offset = RB_layer_save_points_off + RBuff.ct_relative_offset(), + hidden_offset = RB_layer_save_points_off + RBuff.ht_relative_offset(); - switch(rnnMode) - { - case miopenRNNRELU: - case miopenRNNTANH: - // printf("run rnn gpu inference \n"); - wei_len = hy_h; - hid_off = 0; - break; - case miopenLSTM: - // printf("run lstm gpu inference \n"); - wei_len = hy_h * 4; - hid_off = bi * hy_h * 5; - break; - case miopenGRU: - // printf("run gru gpu inference \n"); - wei_len = hy_h * 3; - hid_off = bi * hy_h * 3; - break; - } + const size_t cell_offset_pre = + (time_id == 0) ? 
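+ // the first time step has no previous cell state, so the previous-cell
+ // offset stays 0; later steps point at the ct save-point of time_id - 1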
0 + : RBuff.layer_offset(layer_id) + + RBuff.gemm_write_relative_offset(bacc_per_time[time_id - 1]) + + RBuff.ct_relative_offset(); - ActivationDescriptor tanhDesc, sigDesc, activDesc; - sigDesc = {miopenActivationLOGISTIC, 1, 0, 1}; - tanhDesc = {miopenActivationTANH, 1, 1, 1}; - if(rnnMode == miopenRNNRELU) - { - activDesc = {miopenActivationRELU, 1, 0, 1}; - } - else if(rnnMode == miopenRNNTANH) - { - activDesc = {miopenActivationTANH, 1, 1, 1}; - } + const size_t activ_cell_offset = + RBuff.extra_save_point_offset(layer_id, bacc_per_time[time_id]); - for(int li = 0; li < nLayers; li++) - { - int hid_shift = li * batch_n * hy_stride; - int hx_shift = li * hy_n * bi_stride; - int wei_shift_bias_temp = static_cast(wei_shift_bias) + li * 2 * wei_stride; + LSTMForwardHiddenStateUpdate(handle, + wDesc.GetType(), + false, + is_seq_begin, + direction, + max_batch, + cur_batch, + use_batch, - // from input - if(li == 0) - { - if(inputMode == miopenRNNskip) - { - x_size[1] = batch_n; - x_size[2] = hy_h; - sp_size[1] = batch_n; - sp_size[2] = hy_h; - x_desc = miopen::TensorDescriptor(wDesc.GetType(), x_size, x_stride); - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + hidden_size, + hy_stride, + wei_len, + wei_stride, + cx, + cx_offset, + reserveSpace, + i_offset, + f_offset, + o_offset, + c_offset, + cell_offset, + cell_offset_pre, + activ_cell_offset, + hidden_offset); + }; - for(int gi = 0; gi < nHiddenTensorsPerLayer * bi; gi++) - { - CopyTensor(handle, x_desc, x, sp_desc, workSpace, 0, gi * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - miopen::GemmDescriptor gemm_desc = - GemmDescriptor{false, - false, - true, - batch_n, - wei_len * bi, - in_h, - in_stride, - in_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; // RNN does not support determinism + auto call_hy_cy_update = [&RBuff, + &get_HxBuff_offset, + &bacc_per_time, + &in_n, + &handle, + &wDesc, + reserveSpace, + hy, + cy, + max_batch, + hidden_size, + seq_len](int layer_id) { + if(hy != nullptr || (cy != nullptr)) + { + auto hcy_layer_offset = get_HxBuff_offset(layer_id); - miopenStatus_t gemm_status = CallGemm( - handle, gemm_desc, x, 0, w, 0, workSpace, hid_shift, GemmBackend_t::rocblas); + const std::vector hcy_src_stride{ + RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 1}; + const std::vector hcy_dst_stride{ + static_cast(hidden_size * max_batch), static_cast(hidden_size), 1}; - if(gemm_status != miopenStatusSuccess) + for(int time_i = seq_len - 1; time_i >= 0; time_i--) + { + auto copy_batch = (time_i == seq_len - 1) ? 
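+ // walk time steps from last to first: the final step flushes every still
+ // active sequence, earlier steps flush only the sequences that end there
+ // (the drop in batch count between t and t + 1)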
in_n.at(time_i) + : in_n.at(time_i) - in_n.at(time_i + 1); + if(copy_batch > 0) { - if(gemm_status == miopenStatusNotImplemented) + auto batch_id_relative = in_n.at(time_i) - copy_batch; + auto batch_id_abs = bacc_per_time[time_i] + batch_id_relative; + + auto hcy_batch_offset = batch_id_relative * hidden_size; + + auto src_batch_offset = RBuff.layer_offset(layer_id) + + RBuff.gemm_write_relative_offset(batch_id_abs); + + const std::vector hcy_copy_size{ + 1, static_cast(copy_batch), static_cast(hidden_size)}; + + auto src_desc = + miopen::TensorDescriptor(wDesc.GetType(), hcy_copy_size, hcy_src_stride); + auto dst_desc = + miopen::TensorDescriptor(wDesc.GetType(), hcy_copy_size, hcy_dst_stride); + + if(hy != nullptr) { - MIOPEN_LOG_E("GEMM not implemented"); + CopyTensor(handle, + src_desc, + reserveSpace, + dst_desc, + hy, + src_batch_offset + RBuff.ht_relative_offset(), + hcy_layer_offset + hcy_batch_offset); } - else + + if(cy != nullptr) { - MIOPEN_LOG_E("GEMM failed"); + CopyTensor(handle, + src_desc, + reserveSpace, + dst_desc, + cy, + src_batch_offset + RBuff.ct_relative_offset(), + hcy_layer_offset + hcy_batch_offset); } } - // Update time - profileRNNkernels(handle, 1, ctime); } } - else - { - wei_shift = (in_h + hy_h) * wei_stride + (li - 1) * (bi * hy_h + hy_h) * wei_stride; - prelayer_shift = (li - 1) * batch_n * hy_stride + hid_off; - - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - batch_n, - wei_len * bi, - hy_h * bi, - hy_stride, - bi_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; - miopenStatus_t gemm_status = CallGemm(handle, - gemm_desc, - workSpace, - prelayer_shift, - w, - wei_shift, - workSpace, - hid_shift, - GemmBackend_t::rocblas); + }; - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - { - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); - } + auto call_sync_all_stream_pull_to_root_stream = [&stream_pull, root_stream_id]() { + const miopen::HipEventPtr main_event = make_hip_fast_event(); + hipEventRecord(main_event.get(), stream_pull[root_stream_id]); - if(biasMode != 0u) + for(int i = 0; i < stream_pull.size(); i++) { - alpha0 = 1; - alpha1 = 1; - beta_t = 0; - - w_size[1] = 1; - w_size[2] = wei_stride; - sp_size[1] = batch_n; - sp_size[2] = wei_stride; - w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - workSpace, - hid_shift, - wei_shift_bias_temp, - hid_shift, - true); - // Update time - profileRNNkernels(handle, 1, ctime); + if(i != root_stream_id) + hipStreamWaitEvent(stream_pull[i], main_event.get(), 0); } + }; - if(rnnMode == miopenGRU) + auto sync_root_to_all_stream_pull = [&stream_pull, root_stream_id]() { + hipStream_t root_stream = stream_pull[root_stream_id]; + for(int i = 0; i < stream_pull.size(); i++) { - sp_size[1] = batch_n; - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - - alpha0 = 0; - alpha1 = 0; - beta_t = 0; - for(int bs = 0; bs < bi; bs++) + if(i != root_stream_id) { - CopyTensor(handle, - sp_desc, - workSpace, - sp_desc, - workSpace, - hid_shift + bs * wei_len + 2 * hy_h, - hid_shift + hid_off + bs * 
hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - hid_shift + bs * wei_len + 2 * hy_h, - hid_shift + bs * wei_len + 2 * hy_h, - hid_shift + bs * wei_len + 2 * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + const miopen::HipEventPtr sync_event = make_hip_fast_event(); + hipEventRecord(sync_event.get(), stream_pull[i]); + hipStreamWaitEvent(root_stream, sync_event.get(), 0); } } + }; - if(biasMode != 0u) - { - wei_shift_bias_temp += wei_stride; + if(seq_len == 0) + return; - alpha0 = 1; - alpha1 = 1; - beta_t = 0; + const int try_chunks_cnt = 16; + const int time_chunk_sz = ((seq_len + try_chunks_cnt - 1) / try_chunks_cnt); + const int chunks_cnt = (seq_len + time_chunk_sz - 1) / time_chunk_sz; - if(hx != nullptr) - { - sp_size[1] = batch_n; - sp_size[2] = wei_stride; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + std::vector layer_inx_cur_time(nLayers, 0); + std::vector layer_hx_cur_time(nLayers, 0); + std::vector layer_upd_cur_time(nLayers, 0); - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - workSpace, - hid_shift, - wei_shift_bias_temp, - hid_shift, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else - { - sp_size[1] = batch_n - in_n.at(0); - sp_size[2] = wei_len; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - w_size[1] = 1; - w_size[2] = wei_len; - w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); + std::vector> layer_chunk_end_event; - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - workSpace, - hid_shift + in_n.at(0) * hy_stride, - wei_shift_bias_temp, - hid_shift + in_n.at(0) * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + layer_chunk_end_event.resize(nLayers); + for(int layer_id = 0; layer_id < nLayers; layer_id++) + { + layer_chunk_end_event[layer_id].resize(chunks_cnt); + for(int chunk_id = 0; chunk_id < chunks_cnt; chunk_id++) + layer_chunk_end_event[layer_id][chunk_id] = make_hip_fast_event(); + } - if(dirMode != 0u) - { - if(in_n.at(0) == in_n.at(seqLen - 1)) - { - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - workSpace, - hid_shift + wei_len, - wei_shift_bias_temp + wei_len, - hid_shift + wei_len, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else - { - int cur_batch = 0; - for(int ti = 0; ti < seqLen; ti++) - { - if(ti != (seqLen - 1)) - { - offset = hid_shift + cur_batch * hy_stride; + std::vector layer_stream_id(nLayers, 2); + layer_stream_id[0] = 1; - sp_size[1] = in_n.at(ti + 1); - sp_size[2] = wei_len; - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + auto call_inx_next_chunk_preload = [&](int layer_id) { + auto start_time = layer_inx_cur_time[layer_id]; + auto time_cnt = std::min(time_chunk_sz, seq_len - start_time); - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - workSpace, - offset + wei_len, - wei_shift_bias_temp + wei_len, - offset + wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); - } - cur_batch += in_n.at(ti); - } - } - } - } + call_x_gemm(layer_id, start_time, time_cnt); + 
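// once the chunk's input GEMM is issued, advance this layer's preload
+ // cursor by a whole chunk
+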
layer_inx_cur_time[layer_id] += time_chunk_sz; + }; + + auto call_hx_next_gemm = [&](int layer_id) { + auto cur_time = layer_hx_cur_time[layer_id]; + if(cur_time < seq_len) + { + call_hx_gemm(layer_id, cur_time); + layer_hx_cur_time[layer_id]++; } + }; - // from hidden state - int bacc = 0; - int baccbi = batch_n; - for(int ti = 0; ti < seqLen; ti++) + auto call_next_hidden_state_update = [&](int layer_id) { + auto cur_time = layer_upd_cur_time[layer_id]; + if(cur_time < seq_len) { - baccbi -= in_n.at(seqLen - 1 - ti); - wei_shift = in_h * wei_stride + li * (bi * hy_h + hy_h) * wei_stride; - int pretime_shift = 0; - int use_time = 0; + call_hidden_state_update(layer_id, cur_time); + layer_upd_cur_time[layer_id]++; + } + }; - for(int ri = 0; ri < bi; ri++) - { - int cur_time = ri == 0 ? ti : seqLen - 1 - ti; - int cur_batch = ri == 0 ? bacc : baccbi; - offset = hid_shift + cur_batch * hy_stride; - if(ti > 0) - { - pretime_shift = - ri == 0 ? hid_shift + (bacc - in_n.at(ti - 1)) * hy_stride - : hid_shift + (baccbi + in_n.at(seqLen - 1 - ti)) * hy_stride; - use_time = ri == 0 ? ti : seqLen - ti; - } + auto call_next_chunk_compute = [&handle, + &stream_pull, + &layer_stream_id, + &call_next_hidden_state_update, + &call_hx_next_gemm, + &call_inx_next_chunk_preload, + &layer_upd_cur_time, + &layer_chunk_end_event, + time_chunk_sz, + seq_len](int layer_id) { + auto stream_id = layer_stream_id[layer_id]; + handle.SetStreamFromPool(stream_id); - if(in_n.at(cur_time) > 0) - { - if(ti == 0) - { - if(hx != nullptr) - { - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - in_n.at(cur_time), - wei_len, - hy_h, - uni_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; + const int chunk_id = layer_upd_cur_time[layer_id] / time_chunk_sz; + const int chunk_time = std::min(time_chunk_sz, seq_len - chunk_id * time_chunk_sz); - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, - hx, - hx_shift + ri * hy_n * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, - workSpace, - static_cast(offset) + ri * wei_len, - GemmBackend_t::rocblas); + if(layer_id > 0 && layer_stream_id[layer_id - 1] != stream_id) + { + hipStreamWaitEvent( + stream_pull[stream_id], layer_chunk_end_event[layer_id - 1][chunk_id].get(), 0); + } - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - { - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - miopen::GemmDescriptor gemm_desc = - GemmDescriptor{false, - false, - true, - (in_n.at(cur_time) - in_n.at(use_time)), - wei_len, - hy_h, - uni_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, - hx, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, - workSpace, - static_cast(offset) + ri * wei_len + - in_n.at(use_time) * hy_stride, - GemmBackend_t::rocblas); + if(!(layer_id == 0 && chunk_id == 1)) + { + call_inx_next_chunk_preload(layer_id); + } - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - 
{ - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); - } + for(int time_id = 0; time_id < chunk_time; time_id++) + { + call_hx_next_gemm(layer_id); + call_next_hidden_state_update(layer_id); + } + hipEventRecord(layer_chunk_end_event[layer_id][chunk_id].get(), stream_pull[stream_id]); + }; - if(in_n.at(use_time) > 0) - { - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - in_n.at(use_time), - wei_len, - hy_h, - hy_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; + { // reserveSpace clean set 0 + const int fill_val = 0; + // if(biasMode == 0u) req + hipMemsetAsync(reserveSpace, fill_val, reserveSpaceSize, handle.GetStream()); + } - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, - workSpace, - pretime_shift + hid_off + ri * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, - workSpace, - static_cast(offset) + ri * wei_len, - GemmBackend_t::rocblas); + // stage 0 bias and input preload + // stage 0.2 first chunk compute and preload + { + call_sync_all_stream_pull_to_root_stream(); + const auto first_layer_id = 0; + const auto stream_id = layer_stream_id[first_layer_id]; // 1 + const auto extra_stream_id = 2; - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - { - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); - } - } + handle.SetStreamFromPool(stream_id); - // update hidden status - sp_size[1] = in_n.at(cur_time); - if(rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) - { - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + if(biasMode != 0u) + call_bias_add(first_layer_id); - activDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else if(rnnMode == miopenLSTM) - { - if(algoMode == miopenRNNdefault) - { - LSTMForwardHiddenStateUpdate( - handle, - wDesc.GetType(), - true, - ti == 0, - ri, - in_n.at(0), - in_n.at(cur_time), - in_n.at(use_time), - hy_h, - hy_stride, - wei_len, - wei_stride, - cx, - hx_shift + ri * hy_n * hy_h, - workSpace, - offset + static_cast(ri) * wei_len, - offset + hy_h + static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - pretime_shift + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - 0, - offset + hid_off + static_cast(ri) * hy_h); + call_next_chunk_compute(first_layer_id); - // Update time - profileRNNkernels(handle, 1, ctime); - continue; - } + handle.SetStreamFromPool(extra_stream_id); - // active gate i, f, o - sp_size[2] = hy_h * 3; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + if(biasMode != 0u) + { + for(int layer_id = 1; layer_id < nLayers; layer_id++) + call_bias_add(layer_id); + } - sigDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + call_inx_next_chunk_preload(first_layer_id); - // active gate c - sp_size[2] = hy_h; - 
sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + // sync first to second stream + const miopen::HipEventPtr next_chunk_inx = make_hip_fast_event(); + hipEventRecord(next_chunk_inx.get(), stream_pull[extra_stream_id]); + hipStreamWaitEvent(stream_pull[stream_id], next_chunk_inx.get(), 0); + } - tanhDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + for(int layer_id = 0; layer_id < nLayers; layer_id++) + { - // update cell state - alpha0 = 1; - alpha1 = 1; - beta_t = 1; + const auto main_stream_id = 1; + handle.SetStreamFromPool(main_stream_id); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + // check for wich stream was assigned this layer. If it differs from current - set stream + // wait event + if(layer_stream_id[layer_id] != main_stream_id) + { + auto chunk_id = layer_upd_cur_time[layer_id] / time_chunk_sz; + if(chunk_id > 0) + { + hipStreamWaitEvent(stream_pull[main_stream_id], + layer_chunk_end_event[layer_id][chunk_id - 1].get(), + 0); + } - if(ti == 0) - { - if(cx != nullptr) - { - hx_size[1] = in_n.at(cur_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + layer_stream_id[layer_id] = main_stream_id; + } - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - hx_desc, - cx, - &beta_t, - sp_desc, - workSpace, - offset + hy_h + static_cast(ri) * wei_len, - hx_shift + ri * hy_n * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && cx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + const int start_chunk = layer_upd_cur_time[layer_id] / time_chunk_sz; - sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + const int extra_layer_max_chunks = + start_chunk + + ((layer_id + 1 < nLayers - 1) ? 
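+ // cap the look-ahead work: while another layer still follows, the extra
+ // layer may only advance through half of the remaining chunks; for the
+ // topmost extra layer the cap is effectively unlimited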
(chunks_cnt - start_chunk) / 2 : chunks_cnt); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - hx_desc, - cx, - &beta_t, - sp_desc, - workSpace, - offset + hy_h + static_cast(ri) * wei_len + - static_cast(in_n.at(use_time)) * hy_stride, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h + - static_cast(in_n.at(use_time)) * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); - - sp_size[1] = in_n.at(cur_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - } - - if(in_n.at(use_time) > 0) - { - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(use_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } + for(int chunk_id = start_chunk; chunk_id < chunks_cnt; chunk_id++) + { - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + hy_h + static_cast(ri) * wei_len, - pretime_shift + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + call_next_chunk_compute(layer_id); - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(cur_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } - } - } + int extra_compute_layer = layer_id + 1; + for(; extra_compute_layer < nLayers; extra_compute_layer++) + { + auto extra_chunk_id = layer_upd_cur_time[extra_compute_layer] / time_chunk_sz; + if(extra_chunk_id < extra_layer_max_chunks && extra_chunk_id <= chunk_id) + break; + } - // active cell state - tanhDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + if(extra_compute_layer < nLayers) + call_next_chunk_compute(extra_compute_layer); + } - // update hidden state - beta_t = 0; - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + hid_off + static_cast(ri) * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else if(rnnMode == miopenGRU) - { - // active z, r gate - sp_size[2] = 2 * hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + handle.SetStreamFromPool(main_stream_id); + // update hy, cy + call_hy_cy_update(layer_id); + } - sigDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + handle.SetStreamFromPool(root_stream_id); + hipStreamWaitEvent( + stream_pull[root_stream_id], layer_chunk_end_event[nLayers - 1][chunks_cnt - 1].get(), 0); - // calculate c gate - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + // output tensor copy + { + const std::vector y_copy_size{ + 1, static_cast(total_batch_size), static_cast(out_vec)}; - alpha0 = 1; - alpha1 = 1; - beta_t = 0; + const std::vector y_src_stride{ + RBuff.layer_stride(), static_cast(RBuff.gemm_write_stride()), 
1}; - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + hy_h + static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + const std::vector y_dst_stride{ + static_cast(out_vec * total_batch_size), static_cast(out_vec), 1}; - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + hid_off + static_cast(ri) * hy_h, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + auto src_desc = miopen::TensorDescriptor(wDesc.GetType(), y_copy_size, y_src_stride); + auto y_dst_desc = miopen::TensorDescriptor(wDesc.GetType(), y_copy_size, y_dst_stride); - // active c gate - tanhDesc.Forward(handle, - &alpha, - sp_desc, - workSpace, - &beta, - sp_desc, - workSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + CopyTensor( + handle, src_desc, reserveSpace, y_dst_desc, y, RBuff.ht_offset(nLayers - 1, 0), 0); + } - // calculate hidden state - alpha0 = -1; - alpha1 = 1; - beta_t = 0; - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + sync_root_to_all_stream_pull(); +#else + (void)handle; + (void)seq_array; + (void)xDesc; + (void)x; + (void)hxDesc; + (void)hx; + (void)cx; + (void)wDesc; + (void)w; + (void)yDesc; + (void)y; + (void)hy; + (void)cy; + (void)reserveSpace; + (void)reserveSpaceSize; - alpha0 = 1; - alpha1 = 1; - beta_t = 0; + MIOPEN_THROW("GEMM is not supported"); +#endif +} - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + hid_off + static_cast(ri) * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - - alpha0 = 1; - alpha1 = 1; - beta_t = 1; - if(ti == 0) - { - if(hx != nullptr) - { - hx_size[1] = in_n.at(cur_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); - - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - hx_desc, - hx, - &beta_t, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - hx_shift + ri * hy_n * hy_h, - offset + hid_off + static_cast(ri) * hy_h, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); - - sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, 
- workSpace, - &alpha1, - hx_desc, - hx, - &beta_t, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len + - static_cast(in_n.at(use_time)) * hy_stride, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - offset + hid_off + static_cast(ri) * hy_h + - static_cast(in_n.at(use_time)) * hy_stride, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - - sp_size[1] = in_n.at(cur_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - } - - if(in_n.at(use_time) > 0) - { - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(use_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } - - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - workSpace, - &alpha1, - sp_desc, - workSpace, - &beta_t, - sp_desc, - workSpace, - offset + static_cast(ri) * wei_len, - pretime_shift + hid_off + ri * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - } - } - } - - bacc += in_n.at(ti); - } - - // update hy, cy - if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) - { - hx_size[2] = hy_h; - sp_size[2] = hy_h; - - bacc = batch_n; - baccbi = 0; - for(int ti = seqLen - 1; ti >= 0; ti--) - { - bacc -= in_n.at(ti); - for(int ri = 0; ri < bi; ri++) - { - int cur_time = ri == 0 ? ti : seqLen - 1 - ti; - int cur_batch = ri == 0 ? bacc : baccbi; - int use_batch = 0; - - if(ti < seqLen - 1) - { - int use_time = ri == 0 ? ti + 1 : seqLen - 2 - ti; - use_batch = in_n.at(use_time); - } - - if(in_n.at(cur_time) > use_batch) - { - offset = hid_shift + cur_batch * hy_stride; - - sp_size[1] = in_n.at(cur_time) - use_batch; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - - hx_size[1] = sp_size[1]; - hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); - - if(hy != nullptr) - { - CopyTensor(handle, - sp_desc, - workSpace, - hx_desc, - hy, - static_cast(offset) + hid_off + ri * hy_h + - use_batch * hy_stride, - hx_shift + ri * hy_n * hy_h + use_batch * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - - if(rnnMode == miopenLSTM && cy != nullptr) - { - CopyTensor(handle, - sp_desc, - workSpace, - hx_desc, - cy, - static_cast(offset) + bi * wei_len + ri * hy_h + - use_batch * hy_stride, - hx_shift + ri * hy_n * hy_h + use_batch * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - } - baccbi += in_n.at(seqLen - 1 - ti); - } - } - } - - // output - prelayer_shift = (static_cast(nLayers) - 1) * batch_n * hy_stride + hid_off; - - sp_size[1] = batch_n; - sp_size[2] = hy_h * bi; - y_size[1] = batch_n; - y_size[2] = out_h; - y_desc = miopen::TensorDescriptor(wDesc.GetType(), y_size, y_stride); - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - - CopyTensor(handle, sp_desc, workSpace, y_desc, y, prelayer_shift, 0); - // Update time - profileRNNkernels(handle, 2, ctime); - -#else - (void)hx; - (void)cx; - (void)handle; - (void)seqLen; - (void)xDesc; - (void)x; - (void)w; - (void)y; - (void)hyDesc; - (void)hy; - (void)yDesc; - (void)wDesc; - (void)workSpaceSize; - (void)workSpace; - MIOPEN_THROW("GEMM is not supported"); -#endif -} - -void RNNDescriptor::RNNForwardTraining(Handle& handle, - const int seqLen, - c_array_view xDesc, - ConstData_t x, - const TensorDescriptor& hxDesc, - ConstData_t hx, - const TensorDescriptor& cxDesc, - ConstData_t cx, - const TensorDescriptor& wDesc, - ConstData_t w, - c_array_view yDesc, - Data_t y, - 
const TensorDescriptor& hyDesc, - Data_t hy, - const TensorDescriptor& cyDesc, - Data_t cy, - Data_t workSpace, - size_t workSpaceSize, - Data_t reserveSpace, - size_t reserveSpaceSize) const +// Assuming sequence length is set to > 0 otherwise throw exception. +void RNNDescriptor::RNNForwardInference(Handle& handle, + const int seqLen, + c_array_view xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& cxDesc, + ConstData_t cx, + const TensorDescriptor& wDesc, + ConstData_t w, + c_array_view yDesc, + Data_t y, + const TensorDescriptor& hyDesc, + Data_t hy, + const TensorDescriptor& cyDesc, + Data_t cy, + Data_t workSpace, + size_t workSpaceSize) const { - if(x == nullptr || w == nullptr || y == nullptr) { MIOPEN_THROW(miopenStatusBadParm); @@ -2268,12 +1498,6 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, { MIOPEN_THROW(miopenStatusBadParm); } - - if(reserveSpaceSize < GetReserveSize(handle, seqLen, xDesc)) - { - MIOPEN_THROW("Reservespace is required"); - } - if(workSpaceSize < GetWorkspaceSize(handle, seqLen, xDesc)) { MIOPEN_THROW("Workspace is required"); @@ -2295,35 +1519,38 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, if(paddingMode == miopenRNNIONotPadded) { - return RNNForwardTrainingPackedTensors(handle, - seqLen, - xDesc, - x, - hxDesc, - hx, - cxDesc, - cx, - wDesc, - w, - yDesc, - y, - hyDesc, - hy, - cyDesc, - cy, - reserveSpace, - reserveSpaceSize); + return RNNForwardInferencePacked(handle, + seqLen, + xDesc, + x, + hxDesc, + hx, + cxDesc, + cx, + wDesc, + w, + yDesc, + y, + hyDesc, + hy, + cyDesc, + cy, + workSpace, + workSpaceSize); } else { Data_t packedXIn = workSpace; - size_t packedXInSize; - std::tie(packedXInSize, std::ignore) = + size_t packedXInSize, packedYOutSize; + std::tie(packedXInSize, packedYOutSize) = RNNTensorPaddingConverter::GetTempPackedBuffersSpace(*this, xDesc); Data_t packedYOut = static_cast(reinterpret_cast(workSpace) + packedXInSize); + auto shifted_workSpace = static_cast(reinterpret_cast(workSpace) + + (packedXInSize + packedYOutSize)); + auto shifted_workSpace_size = workSpaceSize - (packedXInSize + packedYOutSize); // std::vector packed_desc; // std::vector packed_desc_ptrs; // RNNTensorPaddingConverter::CreatePackedDescriptor() @@ -2351,24 +1578,24 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, RNNDescriptor packedRnnDesc(*this); packedRnnDesc.SetPaddingmode(miopenRNNIONotPadded); - packedRnnDesc.RNNForwardTrainingPackedTensors(handle, - seqLen, - xDesc, - packedXIn, - hxDesc, - hx, - cxDesc, - cx, - wDesc, - w, - yDesc, - packedYOut, - hyDesc, - hy, - cyDesc, - cy, - reserveSpace, - reserveSpaceSize); + packedRnnDesc.RNNForwardInferencePacked(handle, + seqLen, + xDesc, + packedXIn, + hxDesc, + hx, + cxDesc, + cx, + wDesc, + w, + yDesc, + packedYOut, + hyDesc, + hy, + cyDesc, + cy, + shifted_workSpace, + shifted_workSpace_size); RNNTensorPaddingConverter::ConvertTensorData( handle, yDesc[0], in_n, packedYOut, y, false); @@ -2391,49 +1618,37 @@ void RNNDescriptor::RNNForwardTraining(Handle& handle, handle.AccumKernelTime(eventTime_mS); } #endif -}; - -void RNNDescriptor::RNNForwardTrainingPackedTensors( - Handle& handle, - const int seqLen, - c_array_view xDesc, - ConstData_t x, - const TensorDescriptor& hxDesc, - ConstData_t hx, - const TensorDescriptor& cxDesc, - ConstData_t cx, - const TensorDescriptor& wDesc, - ConstData_t w, - c_array_view yDesc, - Data_t y, - const TensorDescriptor& hyDesc, - Data_t hy, - const TensorDescriptor& cyDesc, - Data_t cy, - 
Data_t reserveSpace, - size_t reserveSpaceSize) const +} +void RNNDescriptor::RNNForwardInferencePacked(Handle& handle, + const int seqLen, + c_array_view xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& cxDesc, + ConstData_t cx, + const TensorDescriptor& wDesc, + ConstData_t w, + c_array_view yDesc, + Data_t y, + const TensorDescriptor& hyDesc, + Data_t hy, + const TensorDescriptor& cyDesc, + Data_t cy, + Data_t workSpace, + size_t workSpaceSize) const { - (void)cxDesc; (void)cyDesc; -#if MIOPEN_USE_GEMM - -#if MIOPEN_BACKEND_HIP - HipEventPtr start = nullptr; - HipEventPtr stop = nullptr; - bool is_profiling = handle.IsProfilingEnabled(); + (void)hxDesc; + (void)cxDesc; - if(is_profiling) - { - handle.EnableProfiling(false); - RNNProfilingBegin(handle, start, stop); - } -#endif +#if MIOPEN_USE_GEMM - // OCL legacy float ctime = 0.; // reset kernel timer profileRNNkernels(handle, 0, ctime); + std::vector in_n; int in_h = xDesc[0].GetLengths()[1]; // input vector size int hy_d = hyDesc.GetLengths()[0]; // biNumLayers int hy_n = hyDesc.GetLengths()[1]; // max batch size @@ -2463,12 +1678,11 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( } int batch_n = 0; - std::vector in_n; for(int i = 0; i < seqLen; i++) { - int batchval, batchvalout; - std::tie(batchval, std::ignore) = miopen::tien<2>(xDesc[i].GetLengths()); - std::tie(batchvalout, std::ignore) = miopen::tien<2>(yDesc[i].GetLengths()); + int batchval, inputvec, batchvalout, outputvec; + std::tie(batchval, inputvec) = miopen::tien<2>(xDesc[i].GetLengths()); + std::tie(batchvalout, outputvec) = miopen::tien<2>(yDesc[i].GetLengths()); if(batchval != batchvalout) { MIOPEN_THROW(miopenStatusBadParm, @@ -2495,40 +1709,6 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( batch_n += batchval; } // input check end - bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); -#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP - - if(rnnMode == miopenLSTM && algoMode == miopenRNNdefault && !use_dropout && nLayers > 1 && - dirMode == miopenRNNunidirection && inputMode != miopenRNNskip && - !(miopen::IsDisabled(ENV(MIOPEN_RNNFWD_exp))) && xDesc[0].GetType() == miopenFloat && - seqLen >= 32) - { - RNNForwardTraining_MS(handle, - in_n, - xDesc[0], - x, - hxDesc, - hx, - cx, - wDesc, - w, - yDesc[0], - y, - hy, - cy, - reserveSpace, - reserveSpaceSize); - - if(is_profiling) - { - float eventTime_mS = RNNProfilingEnd(handle, start, stop); - handle.EnableProfiling(true); - handle.ResetKernelTime(); - handle.AccumKernelTime(eventTime_mS); - } - return; - } -#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP int in_stride = xDesc[0].GetLengths()[1]; int hy_stride = hy_h * bi * static_cast(workspaceScale); @@ -2546,11 +1726,11 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); miopen::TensorDescriptor sp_desc, w_desc, x_desc, y_desc, hx_desc; - sp_size[2] = reserveSpaceSize / GetTypeSize(wDesc.GetType()); + sp_size[2] = workSpaceSize / GetTypeSize(wDesc.GetType()); sp_stride[0] = sp_size[2]; sp_stride[1] = sp_size[2]; sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - SetTensor(handle, sp_desc, reserveSpace, &beta); + SetTensor(handle, sp_desc, workSpace, &beta); // Update time profileRNNkernels(handle, 1, ctime); sp_stride[0] = batch_n * hy_stride; @@ -2592,17 +1772,17 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( { case miopenRNNRELU: case miopenRNNTANH: - // printf("run rnn 
gpu fwd \n"); + // printf("run rnn gpu inference \n"); wei_len = hy_h; - hid_off = static_cast(nLayers) * batch_n * hy_stride; + hid_off = 0; break; case miopenLSTM: - // printf("run lstm gpu fwd \n"); + // printf("run lstm gpu inference \n"); wei_len = hy_h * 4; hid_off = bi * hy_h * 5; break; case miopenGRU: - // printf("run gru gpu fwd \n"); + // printf("run gru gpu inference \n"); wei_len = hy_h * 3; hid_off = bi * hy_h * 3; break; @@ -2640,33 +1820,34 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( for(int gi = 0; gi < nHiddenTensorsPerLayer * bi; gi++) { - CopyTensor(handle, x_desc, x, sp_desc, reserveSpace, 0, gi * hy_h); + CopyTensor(handle, x_desc, x, sp_desc, workSpace, 0, gi * hy_h); // Update time profileRNNkernels(handle, 1, ctime); } } else { - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - batch_n, - wei_len * bi, - in_h, - in_stride, - in_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; + miopen::GemmDescriptor gemm_desc = + GemmDescriptor{false, + false, + true, + batch_n, + wei_len * bi, + in_h, + in_stride, + in_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; // RNN does not support determinism miopenStatus_t gemm_status = CallGemm( - handle, gemm_desc, x, 0, w, 0, reserveSpace, hid_shift, GemmBackend_t::rocblas); + handle, gemm_desc, x, 0, w, 0, workSpace, hid_shift, GemmBackend_t::rocblas); if(gemm_status != miopenStatusSuccess) { @@ -2688,76 +1869,32 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( wei_shift = (in_h + hy_h) * wei_stride + (li - 1) * (bi * hy_h + hy_h) * wei_stride; prelayer_shift = (li - 1) * batch_n * hy_stride + hid_off; - if(use_dropout) - { - std::vector drop_size(2), drop_in_str(2, 1), drop_out_str(2, 1); - drop_size[0] = batch_n; - drop_size[1] = hy_h * bi; - drop_in_str[0] = hy_stride; - drop_out_str[0] = hy_h * bi; - - auto drop_in_desc = - miopen::TensorDescriptor(wDesc.GetType(), drop_size, drop_in_str); - auto drop_out_desc = - miopen::TensorDescriptor(wDesc.GetType(), drop_size, drop_out_str); - - size_t drop_rsv_size = drop_out_desc.GetElementSize(); - size_t drop_rsv_start = - algoMode == miopenRNNdefault && rnnMode == miopenLSTM - ? nLayers * batch_n * hy_stride + nLayers * batch_n * hy_h * bi - : 2 * nLayers * batch_n * hy_stride; - - size_t drop_in_offset = prelayer_shift; - size_t drop_out_offset = - drop_rsv_start + (static_cast(li) - 1) * batch_n * hy_h * bi; - size_t drop_rsv_offset = (drop_rsv_start + (nLayers - 1) * batch_n * hy_h * bi) * - (wDesc.GetType() == miopenFloat ? 4 : 2) + - (li - 1) * drop_rsv_size; - - miopen::deref(dropoutDesc) - .DropoutForward(handle, - drop_in_desc, - drop_in_desc, - reserveSpace, - drop_out_desc, - reserveSpace, - reserveSpace, - drop_rsv_size, - drop_in_offset, - drop_out_offset, - drop_rsv_offset); - // Update time - profileRNNkernels(handle, 1, ctime); - prelayer_shift = drop_out_offset; - } - - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - batch_n, - wei_len * bi, - hy_h * bi, - use_dropout ? 
hy_h * bi : hy_stride, - bi_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; - - miopenStatus_t gemm_status = CallGemm(handle, - gemm_desc, - reserveSpace, - prelayer_shift, - w, - wei_shift, - reserveSpace, - hid_shift, - GemmBackend_t::rocblas); + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + batch_n, + wei_len * bi, + hy_h * bi, + hy_stride, + bi_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + workSpace, + prelayer_shift, + w, + wei_shift, + workSpace, + hid_shift, + GemmBackend_t::rocblas); if(gemm_status != miopenStatusSuccess) { @@ -2791,13 +1928,13 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( miopenTensorOpAdd, &alpha0, sp_desc, - reserveSpace, + workSpace, &alpha1, w_desc, w, &beta_t, sp_desc, - reserveSpace, + workSpace, hid_shift, wei_shift_bias_temp, hid_shift, @@ -2819,24 +1956,25 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( { CopyTensor(handle, sp_desc, - reserveSpace, + workSpace, sp_desc, - reserveSpace, + workSpace, hid_shift + bs * wei_len + 2 * hy_h, hid_shift + hid_off + bs * hy_h); // Update time profileRNNkernels(handle, 1, ctime); + OpTensor(handle, miopenTensorOpAdd, &alpha0, sp_desc, - reserveSpace, + workSpace, &alpha1, sp_desc, - reserveSpace, + workSpace, &beta_t, sp_desc, - reserveSpace, + workSpace, hid_shift + bs * wei_len + 2 * hy_h, hid_shift + bs * wei_len + 2 * hy_h, hid_shift + bs * wei_len + 2 * hy_h); @@ -2859,1063 +1997,3229 @@ void RNNDescriptor::RNNForwardTrainingPackedTensors( sp_size[2] = wei_stride; sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - reserveSpace, - hid_shift, - wei_shift_bias_temp, - hid_shift, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else - { - sp_size[1] = batch_n - in_n.at(0); - sp_size[2] = wei_len; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - w_size[1] = 1; - w_size[2] = wei_len; - w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + workSpace, + hid_shift, + wei_shift_bias_temp, + hid_shift, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else + { + sp_size[1] = batch_n - in_n.at(0); + sp_size[2] = wei_len; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + w_size[1] = 1; + w_size[2] = wei_len; + w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + workSpace, + hid_shift + in_n.at(0) * hy_stride, + wei_shift_bias_temp, + hid_shift + in_n.at(0) * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(dirMode != 0u) + { + if(in_n.at(0) == in_n.at(seqLen - 1)) + { + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + workSpace, + hid_shift + wei_len, + wei_shift_bias_temp + wei_len, + hid_shift + wei_len, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else + { + 
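// per-step batch sizes differ, so the reverse-direction bias must be added
+ // time step by time step over the batches still active at t + 1
+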
int cur_batch = 0; + for(int ti = 0; ti < seqLen; ti++) + { + if(ti != (seqLen - 1)) + { + offset = hid_shift + cur_batch * hy_stride; + + sp_size[1] = in_n.at(ti + 1); + sp_size[2] = wei_len; + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + workSpace, + offset + wei_len, + wei_shift_bias_temp + wei_len, + offset + wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + } + cur_batch += in_n.at(ti); + } + } + } + } + } + + // from hidden state + int bacc = 0; + int baccbi = batch_n; + for(int ti = 0; ti < seqLen; ti++) + { + baccbi -= in_n.at(seqLen - 1 - ti); + wei_shift = in_h * wei_stride + li * (bi * hy_h + hy_h) * wei_stride; + int pretime_shift = 0; + int use_time = 0; + + for(int ri = 0; ri < bi; ri++) + { + int cur_time = ri == 0 ? ti : seqLen - 1 - ti; + int cur_batch = ri == 0 ? bacc : baccbi; + offset = hid_shift + cur_batch * hy_stride; + if(ti > 0) + { + pretime_shift = + ri == 0 ? hid_shift + (bacc - in_n.at(ti - 1)) * hy_stride + : hid_shift + (baccbi + in_n.at(seqLen - 1 - ti)) * hy_stride; + use_time = ri == 0 ? ti : seqLen - ti; + } + + if(in_n.at(cur_time) > 0) + { + if(ti == 0) + { + if(hx != nullptr) + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + in_n.at(cur_time), + wei_len, + hy_h, + uni_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + hx, + hx_shift + ri * hy_n * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + workSpace, + static_cast(offset) + ri * wei_len, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + miopen::GemmDescriptor gemm_desc = + GemmDescriptor{false, + false, + true, + (in_n.at(cur_time) - in_n.at(use_time)), + wei_len, + hy_h, + uni_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + hx, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + workSpace, + static_cast(offset) + ri * wei_len + + in_n.at(use_time) * hy_stride, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(in_n.at(use_time) > 0) + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + in_n.at(use_time), + wei_len, + hy_h, + hy_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + workSpace, + pretime_shift + hid_off + ri * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + workSpace, + static_cast(offset) + ri * wei_len, + 
GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + + // update hidden status + sp_size[1] = in_n.at(cur_time); + if(rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) + { + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + activDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else if(rnnMode == miopenLSTM) + { + if(algoMode == miopenRNNdefault) + { + LSTMForwardHiddenStateUpdate( + handle, + wDesc.GetType(), + true, + ti == 0, + ri, + in_n.at(0), + in_n.at(cur_time), + in_n.at(use_time), + hy_h, + hy_stride, + wei_len, + wei_stride, + cx, + hx_shift + ri * hy_n * hy_h, + workSpace, + offset + static_cast(ri) * wei_len, + offset + hy_h + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + pretime_shift + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + 0, + offset + hid_off + static_cast(ri) * hy_h); + + // Update time + profileRNNkernels(handle, 1, ctime); + continue; + } + + // active gate i, f, o + sp_size[2] = hy_h * 3; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + sigDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // active gate c + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + tanhDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // update cell state + alpha0 = 1; + alpha1 = 1; + beta_t = 1; + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(ti == 0) + { + if(cx != nullptr) + { + hx_size[1] = in_n.at(cur_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + hx_desc, + cx, + &beta_t, + sp_desc, + workSpace, + offset + hy_h + static_cast(ri) * wei_len, + hx_shift + ri * hy_n * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + if(ri == 1 && cx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); + sp_desc = + 
miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + hx_desc, + cx, + &beta_t, + sp_desc, + workSpace, + offset + hy_h + static_cast(ri) * wei_len + + static_cast(in_n.at(use_time)) * hy_stride, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h + + static_cast(in_n.at(use_time)) * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + sp_size[1] = in_n.at(cur_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + } + + if(in_n.at(use_time) > 0) + { + if(in_n.at(use_time) != in_n.at(cur_time)) + { + sp_size[1] = in_n.at(use_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); + } + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + hy_h + static_cast(ri) * wei_len, + pretime_shift + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(in_n.at(use_time) != in_n.at(cur_time)) + { + sp_size[1] = in_n.at(cur_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); + } + } + } + + // active cell state + tanhDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + // update hidden state + beta_t = 0; + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + hid_off + static_cast(ri) * hy_h, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else if(rnnMode == miopenGRU) + { + // active z, r gate + sp_size[2] = 2 * hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + sigDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // calculate c gate + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + alpha0 = 1; + alpha1 = 1; + beta_t = 0; + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + hy_h + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + hid_off + static_cast(ri) * hy_h, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // active c gate + tanhDesc.Forward(handle, + &alpha, + sp_desc, + workSpace, + &beta, + sp_desc, + workSpace, + offset + 2 * static_cast(hy_h) + + 
static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // calculate hidden state + alpha0 = -1; + alpha1 = 1; + beta_t = 0; + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + alpha0 = 1; + alpha1 = 1; + beta_t = 0; + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + hid_off + static_cast(ri) * hy_h, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + alpha0 = 1; + alpha1 = 1; + beta_t = 1; + if(ti == 0) + { + if(hx != nullptr) + { + hx_size[1] = in_n.at(cur_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + hx_desc, + hx, + &beta_t, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + hx_shift + ri * hy_n * hy_h, + offset + hid_off + static_cast(ri) * hy_h, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + hx_desc, + hx, + &beta_t, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len + + static_cast(in_n.at(use_time)) * hy_stride, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + offset + hid_off + static_cast(ri) * hy_h + + static_cast(in_n.at(use_time)) * hy_stride, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + + sp_size[1] = in_n.at(cur_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + } + + if(in_n.at(use_time) > 0) + { + if(in_n.at(use_time) != in_n.at(cur_time)) + { + sp_size[1] = in_n.at(use_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); + } + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + workSpace, + &alpha1, + sp_desc, + workSpace, + &beta_t, + sp_desc, + workSpace, + offset + static_cast(ri) * wei_len, + pretime_shift + hid_off + ri * hy_h, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + } + } + } + + bacc += in_n.at(ti); + } + + // update hy, cy + if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) + { + hx_size[2] = hy_h; + sp_size[2] = hy_h; + + bacc = batch_n; + baccbi = 0; + for(int ti = seqLen - 1; ti >= 0; ti--) + { + bacc -= in_n.at(ti); + for(int ri = 0; ri < bi; ri++) + { + int cur_time = ri == 0 ? ti : seqLen - 1 - ti; + int cur_batch = ri == 0 ? bacc : baccbi; + int use_batch = 0; + + if(ti < seqLen - 1) + { + int use_time = ri == 0 ? 
ti + 1 : seqLen - 2 - ti; + use_batch = in_n.at(use_time); + } + + if(in_n.at(cur_time) > use_batch) + { + offset = hid_shift + cur_batch * hy_stride; + + sp_size[1] = in_n.at(cur_time) - use_batch; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + hx_size[1] = sp_size[1]; + hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + if(hy != nullptr) + { + CopyTensor(handle, + sp_desc, + workSpace, + hx_desc, + hy, + static_cast(offset) + hid_off + ri * hy_h + + use_batch * hy_stride, + hx_shift + ri * hy_n * hy_h + use_batch * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(rnnMode == miopenLSTM && cy != nullptr) + { + CopyTensor(handle, + sp_desc, + workSpace, + hx_desc, + cy, + static_cast(offset) + bi * wei_len + ri * hy_h + + use_batch * hy_stride, + hx_shift + ri * hy_n * hy_h + use_batch * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + } + baccbi += in_n.at(seqLen - 1 - ti); + } + } + } + + // output + prelayer_shift = (static_cast(nLayers) - 1) * batch_n * hy_stride + hid_off; + + sp_size[1] = batch_n; + sp_size[2] = hy_h * bi; + y_size[1] = batch_n; + y_size[2] = out_h; + y_desc = miopen::TensorDescriptor(wDesc.GetType(), y_size, y_stride); + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + CopyTensor(handle, sp_desc, workSpace, y_desc, y, prelayer_shift, 0); + // Update time + profileRNNkernels(handle, 2, ctime); + +#else + (void)hx; + (void)cx; + (void)handle; + (void)seqLen; + (void)xDesc; + (void)x; + (void)w; + (void)y; + (void)hyDesc; + (void)hy; + (void)yDesc; + (void)wDesc; + (void)workSpaceSize; + (void)workSpace; + MIOPEN_THROW("GEMM is not supported"); +#endif +} + +void RNNDescriptor::RNNForwardTraining(Handle& handle, + const int seqLen, + c_array_view xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& cxDesc, + ConstData_t cx, + const TensorDescriptor& wDesc, + ConstData_t w, + c_array_view yDesc, + Data_t y, + const TensorDescriptor& hyDesc, + Data_t hy, + const TensorDescriptor& cyDesc, + Data_t cy, + Data_t workSpace, + size_t workSpaceSize, + Data_t reserveSpace, + size_t reserveSpaceSize) const +{ + + if(x == nullptr || w == nullptr || y == nullptr) + { + MIOPEN_THROW(miopenStatusBadParm); + } + if(hxDesc.GetSize() != cxDesc.GetSize() || hxDesc.GetSize() != hyDesc.GetSize() || + hxDesc.GetSize() != cyDesc.GetSize()) + { + MIOPEN_THROW(miopenStatusBadParm); + } + + if(reserveSpaceSize < GetReserveSize(handle, seqLen, xDesc)) + { + MIOPEN_THROW("Reservespace is required"); + } + + if(workSpaceSize < GetWorkspaceSize(handle, seqLen, xDesc)) + { + MIOPEN_THROW("Workspace is required"); + } + +#if MIOPEN_BACKEND_HIP + HipEventPtr start = nullptr; + HipEventPtr stop = nullptr; + bool is_profiling = handle.IsProfilingEnabled(); + + if(is_profiling) + { + handle.EnableProfiling(false); + RNNProfilingBegin(handle, start, stop); + } + try + { +#endif + + if(paddingMode == miopenRNNIONotPadded) + { + return RNNForwardTrainingPackedTensors(handle, + seqLen, + xDesc, + x, + hxDesc, + hx, + cxDesc, + cx, + wDesc, + w, + yDesc, + y, + hyDesc, + hy, + cyDesc, + cy, + reserveSpace, + reserveSpaceSize); + } + else + { + Data_t packedXIn = workSpace; + size_t packedXInSize; + std::tie(packedXInSize, std::ignore) = + RNNTensorPaddingConverter::GetTempPackedBuffersSpace(*this, xDesc); + + Data_t packedYOut = + static_cast(reinterpret_cast(workSpace) + packedXInSize); + + // std::vector packed_desc; + // 
std::vector packed_desc_ptrs; + // RNNTensorPaddingConverter::CreatePackedDescriptor() + // for future developments: as long as we don't use strides from xDesc and yDesc + // we ignoring conversion of this descriptors. + std::vector in_n(seqLen); + + for(int i = 0; i < seqLen; i++) + { + int batchval, batchvalout; + std::tie(batchval, std::ignore) = miopen::tien<2>(xDesc[i].GetLengths()); + std::tie(batchvalout, std::ignore) = miopen::tien<2>(yDesc[i].GetLengths()); + if(batchval != batchvalout) + { + MIOPEN_THROW(miopenStatusBadParm, + "Input batch length: " + std::to_string(batchval) + + ", Output batch length: " + std::to_string(batchvalout)); + } + in_n[i] = batchval; + } + + RNNTensorPaddingConverter::ConvertTensorData( + handle, xDesc[0], in_n, x, packedXIn, true); + + RNNDescriptor packedRnnDesc(*this); + packedRnnDesc.SetPaddingmode(miopenRNNIONotPadded); + + packedRnnDesc.RNNForwardTrainingPackedTensors(handle, + seqLen, + xDesc, + packedXIn, + hxDesc, + hx, + cxDesc, + cx, + wDesc, + w, + yDesc, + packedYOut, + hyDesc, + hy, + cyDesc, + cy, + reserveSpace, + reserveSpaceSize); + + RNNTensorPaddingConverter::ConvertTensorData( + handle, yDesc[0], in_n, packedYOut, y, false); + } + +#if MIOPEN_BACKEND_HIP + } + catch(...) + { + if(is_profiling) + handle.EnableProfiling(true); + throw; + } + + if(is_profiling) + { + float eventTime_mS = RNNProfilingEnd(handle, start, stop); + handle.EnableProfiling(true); + handle.ResetKernelTime(); + handle.AccumKernelTime(eventTime_mS); + } +#endif +}; + +void RNNDescriptor::RNNForwardTrainingPackedTensors( + Handle& handle, + const int seqLen, + c_array_view xDesc, + ConstData_t x, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& cxDesc, + ConstData_t cx, + const TensorDescriptor& wDesc, + ConstData_t w, + c_array_view yDesc, + Data_t y, + const TensorDescriptor& hyDesc, + Data_t hy, + const TensorDescriptor& cyDesc, + Data_t cy, + Data_t reserveSpace, + size_t reserveSpaceSize) const +{ + (void)cxDesc; + (void)cyDesc; +#if MIOPEN_USE_GEMM + +#if MIOPEN_BACKEND_HIP + HipEventPtr start = nullptr; + HipEventPtr stop = nullptr; + bool is_profiling = handle.IsProfilingEnabled(); + + if(is_profiling) + { + handle.EnableProfiling(false); + RNNProfilingBegin(handle, start, stop); + } +#endif + + // OCL legacy + float ctime = 0.; + // reset kernel timer + profileRNNkernels(handle, 0, ctime); + + int in_h = xDesc[0].GetLengths()[1]; // input vector size + int hy_d = hyDesc.GetLengths()[0]; // biNumLayers + int hy_n = hyDesc.GetLengths()[1]; // max batch size + int hy_h = hyDesc.GetLengths()[2]; // hidden size + int out_h = yDesc[0].GetLengths()[1]; // output vector size + int bi = dirMode != 0u ? 
2 : 1; + + if(in_h <= 0 || hy_h <= 0 || hy_n <= 0 || hy_d <= 0 || out_h <= 0 || seqLen <= 0) + { + MIOPEN_THROW(miopenStatusBadParm); + } + + if(out_h != (bi * hy_h)) + { + MIOPEN_THROW(miopenStatusBadParm, "Output size doesn't match hidden state size!"); + } + + if(inputMode == miopenRNNskip) + { + if(in_h != hy_h) + { + MIOPEN_THROW(miopenStatusBadParm, + "The input tensor size must equal to the hidden " + "state size of the network in SKIP_INPUT mode!"); + } + in_h = 0; + } + + int batch_n = 0; + std::vector in_n; + for(int i = 0; i < seqLen; i++) + { + int batchval, batchvalout; + std::tie(batchval, std::ignore) = miopen::tien<2>(xDesc[i].GetLengths()); + std::tie(batchvalout, std::ignore) = miopen::tien<2>(yDesc[i].GetLengths()); + if(batchval != batchvalout) + { + MIOPEN_THROW(miopenStatusBadParm, + "Input batch length: " + std::to_string(batchval) + + ", Output batch length: " + std::to_string(batchvalout)); + } + if(i == 0) + { + if(batchval <= 0) + { + MIOPEN_THROW(miopenStatusBadParm, "Input batch is ZERO!"); + } + } + else + { + if(batchval > in_n.back() || batchval < 0) + { + MIOPEN_THROW(miopenStatusBadParm, + "Incorrect input batch size at time " + std::to_string(i) + + "! Batch size must not ascend!"); + } + } + in_n.push_back(batchval); + batch_n += batchval; + } + // input check end + bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); +#if MIOPEN_USE_GEMM && MIOPEN_BACKEND_HIP + + if(rnnMode == miopenLSTM && algoMode == miopenRNNdefault && !use_dropout && nLayers > 1 && + dirMode == miopenRNNunidirection && inputMode != miopenRNNskip && + !(miopen::IsDisabled(ENV(MIOPEN_RNNFWD_exp))) && xDesc[0].GetType() == miopenFloat && + seqLen >= 32) + { + RNNForwardTraining_MS(handle, + in_n, + xDesc[0], + x, + hxDesc, + hx, + cx, + wDesc, + w, + yDesc[0], + y, + hy, + cy, + reserveSpace, + reserveSpaceSize); + + if(is_profiling) + { + float eventTime_mS = RNNProfilingEnd(handle, start, stop); + handle.EnableProfiling(true); + handle.ResetKernelTime(); + handle.AccumKernelTime(eventTime_mS); + } + return; + } + + if((rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) && !use_dropout && + inputMode != miopenRNNskip && !(miopen::IsDisabled(ENV(MIOPEN_RNNFWD_exp)))) + { + + RNNForwardTrainingTanhRelu(handle, + in_n, + xDesc[0], + x, + hxDesc, + hx, + wDesc, + w, + yDesc[0], + y, + hyDesc, + hy, + reserveSpace, + reserveSpaceSize); + if(is_profiling) + { + float eventTime_mS = RNNProfilingEnd(handle, start, stop); + handle.EnableProfiling(true); + handle.ResetKernelTime(); + handle.AccumKernelTime(eventTime_mS); + } + return; + } + +#endif // MIOPEN_USE_GEMM&& MIOPEN_BACKEND_HIP + + int in_stride = xDesc[0].GetLengths()[1]; + int hy_stride = hy_h * bi * static_cast(workspaceScale); + int out_stride = out_h; + int wei_stride = hy_h * bi * static_cast(nHiddenTensorsPerLayer); + int uni_stride = hy_h; + int bi_stride = hy_h * bi; + + size_t wei_shift_bias = (in_h + hy_h + (bi * hy_h + hy_h) * (nLayers - 1)) * wei_stride; + size_t offset; + float alpha0, alpha1, beta_t; + float alpha = 1, beta = 0; + + std::vector sp_size(3, 1), sp_stride(3, 1), w_size(3, 1), w_stride(3, 1), x_size(3, 1), + x_stride(3, 1), y_size(3, 1), y_stride(3, 1), hx_size(3, 1), hx_stride(3, 1); + miopen::TensorDescriptor sp_desc, w_desc, x_desc, y_desc, hx_desc; + + sp_size[2] = reserveSpaceSize / GetTypeSize(wDesc.GetType()); + sp_stride[0] = sp_size[2]; + sp_stride[1] = sp_size[2]; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + SetTensor(handle, sp_desc, 
reserveSpace, &beta); + // Update time + profileRNNkernels(handle, 1, ctime); + sp_stride[0] = batch_n * hy_stride; + sp_stride[1] = hy_stride; + sp_size[2] = 1; + w_stride[0] = wei_stride; + w_stride[1] = wei_stride; + x_stride[0] = batch_n * in_stride; + x_stride[1] = in_stride; + y_stride[0] = batch_n * out_stride; + y_stride[1] = out_stride; + if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) + { + hx_size[2] = hy_d * hy_n * hy_h; + hx_stride[0] = hx_size[2]; + hx_stride[1] = hx_size[2]; + hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + if(hy != nullptr) + { + SetTensor(handle, hx_desc, hy, &beta); + // Update time + profileRNNkernels(handle, 1, ctime); + } + if(rnnMode == miopenLSTM && cy != nullptr) + { + SetTensor(handle, hx_desc, cy, &beta); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + hx_stride[0] = in_n.at(0) * uni_stride; + hx_stride[1] = uni_stride; + + int wei_shift, prelayer_shift; + int wei_len = 0; + int hid_off = 0; + + switch(rnnMode) + { + case miopenRNNRELU: + case miopenRNNTANH: + // printf("run rnn gpu fwd \n"); + wei_len = hy_h; + hid_off = static_cast(nLayers) * batch_n * hy_stride; + break; + case miopenLSTM: + // printf("run lstm gpu fwd \n"); + wei_len = hy_h * 4; + hid_off = bi * hy_h * 5; + break; + case miopenGRU: + // printf("run gru gpu fwd \n"); + wei_len = hy_h * 3; + hid_off = bi * hy_h * 3; + break; + } + + ActivationDescriptor tanhDesc, sigDesc, activDesc; + sigDesc = {miopenActivationLOGISTIC, 1, 0, 1}; + tanhDesc = {miopenActivationTANH, 1, 1, 1}; + if(rnnMode == miopenRNNRELU) + { + activDesc = {miopenActivationRELU, 1, 0, 1}; + } + else if(rnnMode == miopenRNNTANH) + { + activDesc = {miopenActivationTANH, 1, 1, 1}; + } + + for(int li = 0; li < nLayers; li++) + { + int hid_shift = li * batch_n * hy_stride; + int hx_shift = li * hy_n * bi_stride; + int wei_shift_bias_temp = static_cast(wei_shift_bias) + li * 2 * wei_stride; + + // from input + if(li == 0) + { + if(inputMode == miopenRNNskip) + { + x_size[1] = batch_n; + x_size[2] = hy_h; + sp_size[1] = batch_n; + sp_size[2] = hy_h; + x_desc = miopen::TensorDescriptor(wDesc.GetType(), x_size, x_stride); + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + for(int gi = 0; gi < nHiddenTensorsPerLayer * bi; gi++) + { + CopyTensor(handle, x_desc, x, sp_desc, reserveSpace, 0, gi * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + batch_n, + wei_len * bi, + in_h, + in_stride, + in_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = CallGemm( + handle, gemm_desc, x, 0, w, 0, reserveSpace, hid_shift, GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + wei_shift = (in_h + hy_h) * wei_stride + (li - 1) * (bi * hy_h + hy_h) * wei_stride; + prelayer_shift = (li - 1) * batch_n * hy_stride + hid_off; + + if(use_dropout) + { + std::vector drop_size(2), drop_in_str(2, 1), drop_out_str(2, 1); + drop_size[0] = batch_n; + drop_size[1] = hy_h * bi; + drop_in_str[0] = hy_stride; + drop_out_str[0] = hy_h * bi; + + auto drop_in_desc = + 
miopen::TensorDescriptor(wDesc.GetType(), drop_size, drop_in_str); + auto drop_out_desc = + miopen::TensorDescriptor(wDesc.GetType(), drop_size, drop_out_str); + + size_t drop_rsv_size = drop_out_desc.GetElementSize(); + size_t drop_rsv_start = + algoMode == miopenRNNdefault && rnnMode == miopenLSTM + ? nLayers * batch_n * hy_stride + nLayers * batch_n * hy_h * bi + : 2 * nLayers * batch_n * hy_stride; + + size_t drop_in_offset = prelayer_shift; + size_t drop_out_offset = + drop_rsv_start + (static_cast(li) - 1) * batch_n * hy_h * bi; + size_t drop_rsv_offset = (drop_rsv_start + (nLayers - 1) * batch_n * hy_h * bi) * + (wDesc.GetType() == miopenFloat ? 4 : 2) + + (li - 1) * drop_rsv_size; + + miopen::deref(dropoutDesc) + .DropoutForward(handle, + drop_in_desc, + drop_in_desc, + reserveSpace, + drop_out_desc, + reserveSpace, + reserveSpace, + drop_rsv_size, + drop_in_offset, + drop_out_offset, + drop_rsv_offset); + // Update time + profileRNNkernels(handle, 1, ctime); + prelayer_shift = drop_out_offset; + } + + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + batch_n, + wei_len * bi, + hy_h * bi, + use_dropout ? hy_h * bi : hy_stride, + bi_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + reserveSpace, + prelayer_shift, + w, + wei_shift, + reserveSpace, + hid_shift, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(biasMode != 0u) + { + alpha0 = 1; + alpha1 = 1; + beta_t = 0; + + w_size[1] = 1; + w_size[2] = wei_stride; + sp_size[1] = batch_n; + sp_size[2] = wei_stride; + w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + reserveSpace, + hid_shift, + wei_shift_bias_temp, + hid_shift, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(rnnMode == miopenGRU) + { + sp_size[1] = batch_n; + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + alpha0 = 0; + alpha1 = 0; + beta_t = 0; + for(int bs = 0; bs < bi; bs++) + { + CopyTensor(handle, + sp_desc, + reserveSpace, + sp_desc, + reserveSpace, + hid_shift + bs * wei_len + 2 * hy_h, + hid_shift + hid_off + bs * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + hid_shift + bs * wei_len + 2 * hy_h, + hid_shift + bs * wei_len + 2 * hy_h, + hid_shift + bs * wei_len + 2 * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + + if(biasMode != 0u) + { + wei_shift_bias_temp += wei_stride; + + alpha0 = 1; + alpha1 = 1; + beta_t = 0; + + if(hx != nullptr) + { + sp_size[1] = batch_n; + sp_size[2] = wei_stride; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + reserveSpace, + hid_shift, + wei_shift_bias_temp, + 
hid_shift, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else + { + sp_size[1] = batch_n - in_n.at(0); + sp_size[2] = wei_len; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + w_size[1] = 1; + w_size[2] = wei_len; + w_desc = miopen::TensorDescriptor(wDesc.GetType(), w_size, w_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + reserveSpace, + hid_shift + in_n.at(0) * hy_stride, + wei_shift_bias_temp, + hid_shift + in_n.at(0) * hy_stride, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(dirMode != 0u) + { + if(in_n.at(0) == in_n.at(seqLen - 1)) + { + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + reserveSpace, + hid_shift + wei_len, + wei_shift_bias_temp + wei_len, + hid_shift + wei_len, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else + { + int cur_batch = 0; + for(int ti = 0; ti < seqLen; ti++) + { + if(ti != (seqLen - 1)) + { + offset = hid_shift + cur_batch * hy_stride; + + sp_size[1] = in_n.at(ti + 1); + sp_size[2] = wei_len; + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + w_desc, + w, + &beta_t, + sp_desc, + reserveSpace, + static_cast(offset) + wei_len, + wei_shift_bias_temp + wei_len, + static_cast(offset) + wei_len, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + cur_batch += in_n.at(ti); + } + } + } + } + } + + // from hidden state + int bacc = 0; + int baccbi = batch_n; + for(int ti = 0; ti < seqLen; ti++) + { + baccbi -= in_n.at(seqLen - 1 - ti); + wei_shift = in_h * wei_stride + li * (bi * hy_h + hy_h) * wei_stride; + int pretime_shift = 0; + int use_time = 0; + + for(int ri = 0; ri < bi; ri++) + { + int cur_time = ri == 0 ? ti : seqLen - 1 - ti; + int cur_batch = ri == 0 ? bacc : baccbi; + offset = hid_shift + cur_batch * hy_stride; + if(ti > 0) + { + pretime_shift = + ri == 0 ? hid_shift + (bacc - in_n.at(ti - 1)) * hy_stride + : hid_shift + (baccbi + in_n.at(seqLen - 1 - ti)) * hy_stride; + use_time = ri == 0 ? 
ti : seqLen - ti; + } + + if(in_n.at(cur_time) > 0) + { + if(ti == 0) + { + if(hx != nullptr) + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + in_n.at(cur_time), + wei_len, + hy_h, + uni_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + hx, + hx_shift + ri * hy_n * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + reserveSpace, + static_cast(offset) + ri * wei_len, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + miopen::GemmDescriptor gemm_desc = + GemmDescriptor{false, + false, + true, + (in_n.at(cur_time) - in_n.at(use_time)), + wei_len, + hy_h, + uni_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + hx, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + reserveSpace, + static_cast(offset) + ri * wei_len + + in_n.at(use_time) * hy_stride, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(in_n.at(use_time) > 0) + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + true, + in_n.at(use_time), + wei_len, + hy_h, + hy_stride, + uni_stride, + hy_stride, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + xDesc[0].GetType(), + false}; + + miopenStatus_t gemm_status = + CallGemm(handle, + gemm_desc, + reserveSpace, + pretime_shift + hid_off + ri * hy_h, + w, + wei_shift + ri * wei_len * uni_stride, + reserveSpace, + static_cast(offset) + ri * wei_len, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + + // update hidden status + sp_size[1] = in_n.at(cur_time); + if(rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) + { + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + activDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else if(rnnMode == miopenLSTM) + { + if(algoMode == miopenRNNdefault) + { + LSTMForwardHiddenStateUpdate( + handle, + wDesc.GetType(), + false, + ti == 0, + ri, + in_n.at(0), + in_n.at(cur_time), + in_n.at(use_time), + hy_h, + hy_stride, + wei_len, + wei_stride, + cx, + hx_shift + ri * hy_n * hy_h, + reserveSpace, + offset + static_cast(ri) * wei_len, + offset + hy_h + static_cast(ri) * wei_len, + offset 
+ 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + pretime_shift + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + (li * batch_n + cur_batch) * bi * hy_h + ri * hy_h + + nLayers * batch_n * hy_stride, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + continue; + } + + // active gate i, f, o + sp_size[2] = hy_h * 3; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + sigDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride); + + // active gate c + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + tanhDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len + + nLayers * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + // update cell state + alpha0 = 1; + alpha1 = 1; + beta_t = 1; + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + 3 * static_cast(hy_h) + + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(ti == 0) + { + if(cx != nullptr) + { + hx_size[1] = in_n.at(cur_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + hx_desc, + cx, + &beta_t, + sp_desc, + reserveSpace, + offset + hy_h + static_cast(ri) * wei_len + + nLayers * batch_n * hy_stride, + hx_shift + ri * hy_n * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + } + } + else + { + if(ri == 1 && cx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + hx_desc, + cx, + &beta_t, + sp_desc, + reserveSpace, + offset + hy_h + static_cast(ri) * wei_len + + static_cast(in_n.at(use_time)) * hy_stride + + nLayers * batch_n * hy_stride, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h + + static_cast(in_n.at(use_time)) * hy_stride, + true); + // Update time + profileRNNkernels(handle, 1, ctime); + + sp_size[1] = in_n.at(cur_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + } + + if(in_n.at(use_time) > 0) + { + if(in_n.at(use_time) != in_n.at(cur_time)) + { + sp_size[1] = in_n.at(use_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); 
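+ // The sp_desc above now spans only the in_n.at(use_time) rows that have a
+ // previous timestep; the multiply that follows accumulates the activated
+ // forget gate (read nLayers * batch_n * hy_stride further into reserveSpace)
+ // against the previous cell state at pretime_shift, i.e. c(t) += f(t) * c(t-1),
+ // before the descriptor is restored to in_n.at(cur_time).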
+ } + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + hy_h + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + pretime_shift + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + if(in_n.at(use_time) != in_n.at(cur_time)) + { + sp_size[1] = in_n.at(cur_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); + } + } + } + + // active cell state + tanhDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h + + nLayers * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + // update hidden state + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + static_cast(bi) * wei_len + + static_cast(ri) * hy_h + + static_cast(nLayers) * batch_n * hy_stride, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + else if(rnnMode == miopenGRU) + { + // active z, r gate + sp_size[2] = 2 * hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + sigDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + // calculate c gate + sp_size[2] = hy_h; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + CopyTensor(handle, + sp_desc, + reserveSpace, + sp_desc, + reserveSpace, + static_cast(offset) + 2 * hy_h + ri * wei_len, + static_cast(offset) + hid_off + ri * hy_h + + static_cast(nLayers) * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + alpha0 = 1; + alpha1 = 1; + beta_t = 0; + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + hy_h + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + hid_off + static_cast(ri) * hy_h, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len); + // Update time + profileRNNkernels(handle, 1, ctime); + + // active c gate + tanhDesc.Forward(handle, + &alpha, + sp_desc, + reserveSpace, + &beta, + sp_desc, + reserveSpace, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride); + // Update time + profileRNNkernels(handle, 1, ctime); + + // calculate hidden 
state + alpha0 = -1; + alpha1 = 1; + beta_t = 0; - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - w_desc, - w, - &beta_t, - sp_desc, - reserveSpace, - hid_shift + in_n.at(0) * hy_stride, - wei_shift_bias_temp, - hid_shift + in_n.at(0) * hy_stride, - true); - // Update time - profileRNNkernels(handle, 1, ctime); + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + + alpha0 = 1; + alpha1 = 1; + beta_t = 0; - if(dirMode != 0u) - { - if(in_n.at(0) == in_n.at(seqLen - 1)) - { OpTensor(handle, miopenTensorOpAdd, &alpha0, sp_desc, reserveSpace, &alpha1, - w_desc, - w, + sp_desc, + reserveSpace, &beta_t, sp_desc, reserveSpace, - hid_shift + wei_len, - wei_shift_bias_temp + wei_len, - hid_shift + wei_len, - true); + offset + 2 * static_cast(hy_h) + + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + offset + hid_off + static_cast(ri) * hy_h, + offset + hid_off + static_cast(ri) * hy_h); // Update time profileRNNkernels(handle, 1, ctime); - } - else - { - int cur_batch = 0; - for(int ti = 0; ti < seqLen; ti++) + + alpha0 = 1; + alpha1 = 1; + beta_t = 1; + + if(ti == 0) { - if(ti != (seqLen - 1)) + if(hx != nullptr) { - offset = hid_shift + cur_batch * hy_stride; - - sp_size[1] = in_n.at(ti + 1); - sp_size[2] = wei_len; - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + hx_size[1] = in_n.at(cur_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); OpTensor(handle, - miopenTensorOpAdd, + miopenTensorOpMul, &alpha0, sp_desc, reserveSpace, &alpha1, - w_desc, - w, + hx_desc, + hx, &beta_t, sp_desc, reserveSpace, - static_cast(offset) + wei_len, - wei_shift_bias_temp + wei_len, - static_cast(offset) + wei_len, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + hx_shift + ri * hy_n * hy_h, + offset + hid_off + static_cast(ri) * hy_h, true); // Update time profileRNNkernels(handle, 1, ctime); } - cur_batch += in_n.at(ti); - } - } - } - } - } - - // from hidden state - int bacc = 0; - int baccbi = batch_n; - for(int ti = 0; ti < seqLen; ti++) - { - baccbi -= in_n.at(seqLen - 1 - ti); - wei_shift = in_h * wei_stride + li * (bi * hy_h + hy_h) * wei_stride; - int pretime_shift = 0; - int use_time = 0; - - for(int ri = 0; ri < bi; ri++) - { - int cur_time = ri == 0 ? ti : seqLen - 1 - ti; - int cur_batch = ri == 0 ? bacc : baccbi; - offset = hid_shift + cur_batch * hy_stride; - if(ti > 0) - { - pretime_shift = - ri == 0 ? hid_shift + (bacc - in_n.at(ti - 1)) * hy_stride - : hid_shift + (baccbi + in_n.at(seqLen - 1 - ti)) * hy_stride; - use_time = ri == 0 ? 
ti : seqLen - ti; - } - - if(in_n.at(cur_time) > 0) - { - if(ti == 0) - { - if(hx != nullptr) - { - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - in_n.at(cur_time), - wei_len, - hy_h, - uni_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; - - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, - hx, - hx_shift + ri * hy_n * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, - reserveSpace, - static_cast(offset) + ri * wei_len, - GemmBackend_t::rocblas); - - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - { - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - miopen::GemmDescriptor gemm_desc = - GemmDescriptor{false, - false, - true, - (in_n.at(cur_time) - in_n.at(use_time)), - wei_len, - hy_h, - uni_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; - - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, - hx, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, - reserveSpace, - static_cast(offset) + ri * wei_len + - in_n.at(use_time) * hy_stride, - GemmBackend_t::rocblas); - - if(gemm_status != miopenStatusSuccess) - { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else - { - MIOPEN_LOG_E("GEMM failed"); - } - } - // Update time - profileRNNkernels(handle, 1, ctime); } - - if(in_n.at(use_time) > 0) + else { - miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, - false, - true, - in_n.at(use_time), - wei_len, - hy_h, - hy_stride, - uni_stride, - hy_stride, - 1, // batch count - 0, // Stride A - 0, // Stride B - 0, // Stride C - 1, // alpha - 1, // beta - xDesc[0].GetType(), - false}; + if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) + { + hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); + hx_size[2] = hy_h; + hx_desc = + miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); - miopenStatus_t gemm_status = - CallGemm(handle, - gemm_desc, + sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, reserveSpace, - pretime_shift + hid_off + ri * hy_h, - w, - wei_shift + ri * wei_len * uni_stride, + &alpha1, + hx_desc, + hx, + &beta_t, + sp_desc, reserveSpace, - static_cast(offset) + ri * wei_len, - GemmBackend_t::rocblas); + offset + static_cast(ri) * wei_len + + static_cast(in_n.at(use_time)) * hy_stride + + static_cast(nLayers) * batch_n * hy_stride, + hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, + offset + hid_off + static_cast(ri) * hy_h + + static_cast(in_n.at(use_time)) * hy_stride, + true); + // Update time + profileRNNkernels(handle, 1, ctime); - if(gemm_status != miopenStatusSuccess) + sp_size[1] = in_n.at(cur_time); + sp_desc = + miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + } + + if(in_n.at(use_time) > 0) { - if(gemm_status == miopenStatusNotImplemented) - { - MIOPEN_LOG_E("GEMM not implemented"); - } - else + if(in_n.at(use_time) != in_n.at(cur_time)) { - MIOPEN_LOG_E("GEMM 
failed"); + sp_size[1] = in_n.at(use_time); + sp_desc = miopen::TensorDescriptor( + wDesc.GetType(), sp_size, sp_stride); } + + OpTensor(handle, + miopenTensorOpMul, + &alpha0, + sp_desc, + reserveSpace, + &alpha1, + sp_desc, + reserveSpace, + &beta_t, + sp_desc, + reserveSpace, + offset + static_cast(ri) * wei_len + + static_cast(nLayers) * batch_n * hy_stride, + pretime_shift + hid_off + ri * hy_h, + offset + hid_off + static_cast(ri) * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); } - // Update time - profileRNNkernels(handle, 1, ctime); } } + } + } - // update hidden status - sp_size[1] = in_n.at(cur_time); - if(rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) - { - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + bacc += in_n.at(ti); + } - activDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + // update hy, cy + if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) + { + hx_size[2] = hy_h; + sp_size[2] = hy_h; + + bacc = batch_n; + baccbi = 0; + for(int ti = seqLen - 1; ti >= 0; ti--) + { + bacc -= in_n.at(ti); + for(int ri = 0; ri < bi; ri++) + { + int cur_time = ri == 0 ? ti : seqLen - 1 - ti; + int cur_batch = ri == 0 ? bacc : baccbi; + int use_batch = 0; + + if(ti < seqLen - 1) + { + int use_time = ri == 0 ? ti + 1 : seqLen - 2 - ti; + use_batch = in_n.at(use_time); } - else if(rnnMode == miopenLSTM) + + if(in_n.at(cur_time) > use_batch) { - if(algoMode == miopenRNNdefault) + offset = hid_shift + cur_batch * hy_stride; + + sp_size[1] = in_n.at(cur_time) - use_batch; + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + hx_size[1] = sp_size[1]; + hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + + if(hy != nullptr) { - LSTMForwardHiddenStateUpdate( - handle, - wDesc.GetType(), - false, - ti == 0, - ri, - in_n.at(0), - in_n.at(cur_time), - in_n.at(use_time), - hy_h, - hy_stride, - wei_len, - wei_stride, - cx, - hx_shift + ri * hy_n * hy_h, - reserveSpace, - offset + static_cast(ri) * wei_len, - offset + hy_h + static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - pretime_shift + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - (li * batch_n + cur_batch) * bi * hy_h + ri * hy_h + - nLayers * batch_n * hy_stride, - offset + hid_off + static_cast(ri) * hy_h); + CopyTensor(handle, + sp_desc, + reserveSpace, + hx_desc, + hy, + static_cast(offset) + hid_off + ri * hy_h + + use_batch * hy_stride, + hx_shift + ri * hy_n * hy_h + use_batch * hy_h); + // Update time + profileRNNkernels(handle, 1, ctime); + } + + if(rnnMode == miopenLSTM && cy != nullptr) + { + CopyTensor(handle, + sp_desc, + reserveSpace, + hx_desc, + cy, + static_cast(offset) + bi * wei_len + ri * hy_h + + use_batch * hy_stride, + hx_shift + ri * hy_n * hy_h + use_batch * hy_h); // Update time profileRNNkernels(handle, 1, ctime); - continue; } + } + } + baccbi += in_n.at(seqLen - 1 - ti); + } + } + } + + // output + prelayer_shift = (static_cast(nLayers) - 1) * batch_n * hy_stride + hid_off; + + sp_size[1] = batch_n; + sp_size[2] = hy_h * bi; + y_size[1] = batch_n; + y_size[2] = out_h; + y_desc = 
miopen::TensorDescriptor(wDesc.GetType(), y_size, y_stride); + sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + + CopyTensor(handle, sp_desc, reserveSpace, y_desc, y, prelayer_shift, 0); + // Update time + profileRNNkernels(handle, 2, ctime); + +#if MIOPEN_BACKEND_HIP + if(is_profiling) + { + float eventTime_mS = RNNProfilingEnd(handle, start, stop); + handle.EnableProfiling(true); + handle.ResetKernelTime(); + handle.AccumKernelTime(eventTime_mS); + } +#endif + +#else + (void)handle; + (void)seqLen; + (void)xDesc; + (void)x; + (void)w; + (void)hx; + (void)cx; + (void)y; + (void)hyDesc; + (void)hy; + (void)yDesc; + (void)cy; + (void)hxDesc; + (void)wDesc; + (void)reserveSpace; + (void)reserveSpaceSize; + MIOPEN_THROW("GEMM is not supported"); +#endif +}; + +void RNNDescriptor::RNNBackwardData(Handle& handle, + const int seqLen, + c_array_view yDesc, + ConstData_t y, + c_array_view dyDesc, + ConstData_t dy, + const TensorDescriptor& dhyDesc, + ConstData_t dhy, + const TensorDescriptor& dcyDesc, + ConstData_t dcy, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& hxDesc, + ConstData_t hx, + const TensorDescriptor& cxDesc, + ConstData_t cx, + c_array_view dxDesc, + Data_t dx, + const TensorDescriptor& dhxDesc, + Data_t dhx, + const TensorDescriptor& dcxDesc, + Data_t dcx, + Data_t workSpace, + size_t workSpaceSize, + Data_t reserveSpace, + size_t reserveSpaceSize) const +{ + // Suppress warning + (void)y; + (void)yDesc; + (void)wDesc; - // active gate i, f, o - sp_size[2] = hy_h * 3; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + if(dx == nullptr || w == nullptr || dy == nullptr) + { + MIOPEN_THROW(miopenStatusBadParm); + } + if(dhyDesc.GetSize() != dcyDesc.GetSize() || dhyDesc.GetSize() != hxDesc.GetSize() || + dhyDesc.GetSize() != cxDesc.GetSize() || dhyDesc.GetSize() != dhxDesc.GetSize() || + dhyDesc.GetSize() != dcxDesc.GetSize()) + { + MIOPEN_THROW(miopenStatusBadParm); + } - sigDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride); +#if MIOPEN_BACKEND_HIP + HipEventPtr start = nullptr; + HipEventPtr stop = nullptr; + bool is_profiling = handle.IsProfilingEnabled(); - // active gate c - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + if(is_profiling) + { + handle.EnableProfiling(false); + RNNProfilingBegin(handle, start, stop); + } + try + { +#endif - tanhDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len + - nLayers * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + if(paddingMode == miopenRNNIONotPadded) + { + bool use_dropout = !float_equal(miopen::deref(dropoutDesc).dropout, 0); - // update cell state - alpha0 = 1; - alpha1 = 1; - beta_t = 1; + if((rnnMode == miopenRNNRELU || rnnMode == miopenRNNTANH) && !use_dropout) + { + RNNBackwardDataPackedTensorsRelu(handle, + seqLen, + dyDesc, + dy, + dhy, + w, + dxDesc, + dx, + dhxDesc, + dhx, + workSpace, + workSpaceSize, + reserveSpace, + reserveSpaceSize); + } + else + { + RNNBackwardDataPackedTensors(handle, + seqLen, + dyDesc, + dy, + dhy, + dcy, + w, + hx, + cx, + dxDesc, + dx, + dhxDesc, + dhx, + dcxDesc, + dcx, + workSpace, + workSpaceSize, + 
reserveSpace, + reserveSpaceSize); + } + } + else + { + Data_t packedDYIn = workSpace; + size_t packedDXSize, packedDYSize; + std::tie(packedDXSize, packedDYSize) = + RNNTensorPaddingConverter::GetTempPackedBuffersSpace(*this, dxDesc); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + 3 * static_cast(hy_h) + - static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + Data_t packedDXOut = + static_cast(reinterpret_cast(workSpace) + packedDYSize); - if(ti == 0) - { - if(cx != nullptr) - { - hx_size[1] = in_n.at(cur_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + auto shifted_workSpace = static_cast(reinterpret_cast(workSpace) + + (packedDYSize + packedDXSize)); + auto shifted_workSpace_size = workSpaceSize - (packedDYSize + packedDXSize); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - hx_desc, - cx, - &beta_t, - sp_desc, - reserveSpace, - offset + hy_h + static_cast(ri) * wei_len + - nLayers * batch_n * hy_stride, - hx_shift + ri * hy_n * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && cx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + std::vector in_n(seqLen); - sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + for(int i = 0; i < seqLen; i++) + { + int batchval, batchvalout; + std::tie(batchval, std::ignore) = miopen::tien<2>(dxDesc[i].GetLengths()); + std::tie(batchvalout, std::ignore) = miopen::tien<2>(dyDesc[i].GetLengths()); + if(batchval != batchvalout) + { + MIOPEN_THROW(miopenStatusBadParm, + "Input batch length: " + std::to_string(batchval) + + ", Output batch length: " + std::to_string(batchvalout)); + } + in_n[i] = batchval; + } - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - hx_desc, - cx, - &beta_t, - sp_desc, - reserveSpace, - offset + hy_h + static_cast(ri) * wei_len + - static_cast(in_n.at(use_time)) * hy_stride + - nLayers * batch_n * hy_stride, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h + - static_cast(in_n.at(use_time)) * hy_stride, - true); - // Update time - profileRNNkernels(handle, 1, ctime); + RNNTensorPaddingConverter::ConvertTensorData( + handle, dyDesc[0], in_n, dy, packedDYIn, true); - sp_size[1] = in_n.at(cur_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - } + RNNDescriptor packedRnnDesc(*this); + packedRnnDesc.SetPaddingmode(miopenRNNIONotPadded); - if(in_n.at(use_time) > 0) - { - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(use_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } + packedRnnDesc.RNNBackwardDataPackedTensors(handle, + seqLen, + dyDesc, + packedDYIn, + dhy, + dcy, + w, + hx, + cx, + dxDesc, + packedDXOut, + dhxDesc, + dhx, + dcxDesc, + 
dcx, + shifted_workSpace, + shifted_workSpace_size, + reserveSpace, + reserveSpaceSize); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + hy_h + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - pretime_shift + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + RNNTensorPaddingConverter::ConvertTensorData( + handle, dxDesc[0], in_n, packedDXOut, dx, false); + } - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(cur_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } - } - } +#if MIOPEN_BACKEND_HIP + } + catch(...) + { + if(is_profiling) + handle.EnableProfiling(true); + throw; + } - // active cell state - tanhDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h + - nLayers * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + if(is_profiling) + { + float eventTime_mS = RNNProfilingEnd(handle, start, stop); + handle.EnableProfiling(true); + handle.ResetKernelTime(); + handle.AccumKernelTime(eventTime_mS); + } +#endif +} - // update hidden state - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + static_cast(bi) * wei_len + - static_cast(ri) * hy_h + - static_cast(nLayers) * batch_n * hy_stride, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - else if(rnnMode == miopenGRU) - { - // active z, r gate - sp_size[2] = 2 * hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); +void RNNDescriptor::RNNBackwardDataPackedTensorsRelu( + Handle& handle, + const int seqLen, + c_array_view dyDesc, + ConstData_t dy, + ConstData_t dhy, + ConstData_t w, + c_array_view dxDesc, + Data_t dx, + const TensorDescriptor& dhxDesc, + Data_t dhx, + Data_t workSpace, + size_t workSpaceSize, + Data_t reserveSpace, + size_t reserveSpaceSize) const +{ +#if MIOPEN_USE_GEMM + if(paddingMode != miopenRNNIONotPadded) + { + MIOPEN_THROW("Padded IO is not supported by this solver"); + } - sigDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + if(workSpaceSize < GetWorkspaceSize(handle, seqLen, dxDesc)) + { + MIOPEN_THROW("Workspace is required"); + } - // calculate c gate - sp_size[2] = hy_h; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + if(reserveSpaceSize < GetReserveSize(handle, seqLen, dxDesc)) + { + MIOPEN_THROW("Reservespace is required"); + } - CopyTensor(handle, - sp_desc, - reserveSpace, - sp_desc, - reserveSpace, - static_cast(offset) + 2 * hy_h + ri * wei_len, - static_cast(offset) + hid_off + ri * hy_h + - static_cast(nLayers) * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + auto rnn_data_type = dhxDesc.GetType(); 
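+ // batches keeps the per-timestep batch sizes taken from dxDesc (they must not
+ // ascend over time), and bacc_per_time keeps the running prefix sum of those
+ // sizes, i.e. the row offset of each timestep inside the packed
+ // total_batch_size dimension. For example, batches = {4, 3, 3, 1} gives
+ // bacc_per_time = {0, 4, 7, 10} and total_batch_size = 11. Both are filled by
+ // the validation loop below.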
- alpha0 = 1; - alpha1 = 1; - beta_t = 0; + RnnBatches batches; + RnnBatches bacc_per_time; - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + hy_h + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + int input_size = dxDesc[0].GetLengths()[1]; + int hy_d = dhxDesc.GetLengths()[0]; + int max_batch = dhxDesc.GetLengths()[1]; + int hidden_size = dhxDesc.GetLengths()[2]; + int out_vec_size = dyDesc[0].GetLengths()[1]; + int bi = dirMode != 0u ? 2 : 1; - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + hid_off + static_cast(ri) * hy_h, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len); - // Update time - profileRNNkernels(handle, 1, ctime); + int in_stride = input_size; + int out_stride = out_vec_size; - // active c gate - tanhDesc.Forward(handle, - &alpha, - sp_desc, - reserveSpace, - &beta, - sp_desc, - reserveSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride); - // Update time - profileRNNkernels(handle, 1, ctime); + if(input_size <= 0 || hidden_size <= 0 || max_batch <= 0 || out_vec_size <= 0 || seqLen <= 0) + { + MIOPEN_THROW(miopenStatusBadParm); + } - // calculate hidden state - alpha0 = -1; - alpha1 = 1; - beta_t = 0; + float beta = 0; - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + auto workSpaceDataTypeSize = workSpaceSize / GetTypeSize(rnn_data_type); + auto workSpace_desc = + miopen::TensorDescriptor(rnn_data_type, + std::vector{1, 1, workSpaceDataTypeSize}, + std::vector{workSpaceDataTypeSize, workSpaceDataTypeSize, 1}); + SetTensor(handle, workSpace_desc, workSpace, &beta); - alpha0 = 1; - alpha1 = 1; - beta_t = 0; + if(dhx != nullptr) + { + int dhx_size = max_batch * hidden_size * hy_d; + auto dhx_desc = miopen::TensorDescriptor(rnn_data_type, + std::vector{1, 1, dhx_size}, + std::vector{dhx_size, dhx_size, 1}); + SetTensor(handle, dhx_desc, dhx, &beta); + } - OpTensor(handle, - miopenTensorOpAdd, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + 2 * static_cast(hy_h) + - static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - offset + hid_off + static_cast(ri) * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); + int total_batch_size = 0; + for(int i = 0; i < seqLen; i++) + { + int batchval, inputvec, batchvalout, outputvec; + std::tie(batchval, inputvec) = miopen::tien<2>(dxDesc[i].GetLengths()); + std::tie(batchvalout, outputvec) = miopen::tien<2>(dyDesc[i].GetLengths()); + if(batchval != 
batchvalout) + { + MIOPEN_THROW(miopenStatusBadParm); + } + if(i == 0) + { + if(batchval <= 0) + { + MIOPEN_THROW(miopenStatusBadParm, "Input batch is ZERO!"); + } + } + else + { + if(batchval > batches.back() || batchval < 0) + { + MIOPEN_THROW(miopenStatusBadParm, + "Incorrect input batch size at time " + std::to_string(i) + + "! Batch size must not ascend!"); + } + } + + batches.push_back(batchval); + bacc_per_time.push_back(total_batch_size); + total_batch_size += batchval; + } + + if(out_vec_size != (bi * hidden_size)) + { + MIOPEN_THROW(miopenStatusBadParm, "Output size doesn't match hidden state size!"); + } - alpha0 = 1; - alpha1 = 1; - beta_t = 1; + if(inputMode == miopenRNNskip) + { + if(input_size != hidden_size) + { + MIOPEN_THROW(miopenStatusBadParm, + "The input tensor size must equal to the hidden " + "state size of the network in SKIP_INPUT mode!"); + } + input_size = 0; + } - if(ti == 0) - { - if(hx != nullptr) - { - hx_size[1] = in_n.at(cur_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + ActivationDescriptor activDesc; + if(rnnMode == miopenRNNRELU) + { + activDesc = {miopenActivationRELU, 1, 0, 1}; + } + else if(rnnMode == miopenRNNTANH) + { + activDesc = {miopenActivationTANH, 1, 1, 1}; + } - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - hx_desc, - hx, - &beta_t, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - hx_shift + ri * hy_n * hy_h, - offset + hid_off + static_cast(ri) * hy_h, - true); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - else - { - if(ri == 1 && hx != nullptr && in_n.at(cur_time) > in_n.at(use_time)) - { - hx_size[1] = in_n.at(cur_time) - in_n.at(use_time); - hx_size[2] = hy_h; - hx_desc = - miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + ReluWeightOffsets WeiBuf( + input_size, hidden_size, nLayers, biasMode * 2, bi, nHiddenTensorsPerLayer); - sp_size[1] = in_n.at(cur_time) - in_n.at(use_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + ReluReserveBufferOffsets RBuff(hidden_size, nLayers, total_batch_size, bi, workspaceScale); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - hx_desc, - hx, - &beta_t, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len + - static_cast(in_n.at(use_time)) * hy_stride + - static_cast(nLayers) * batch_n * hy_stride, - hx_shift + ri * hy_n * hy_h + in_n.at(use_time) * hy_h, - offset + hid_off + static_cast(ri) * hy_h + - static_cast(in_n.at(use_time)) * hy_stride, - true); - // Update time - profileRNNkernels(handle, 1, ctime); + auto get_HxBuff_offset = + [bi, hidden_size, max_batch](int layer_id, int batch_id, RnnDirection reverse) { + return (static_cast(hidden_size) * (max_batch)) * + (bi * layer_id + static_cast(reverse)) + + static_cast(hidden_size) * batch_id; + }; - sp_size[1] = in_n.at(cur_time); - sp_desc = - miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); - } + auto back_propagate_dy = + [&RBuff, out_vec_size, &handle, rnn_data_type, workSpace, dy, &WeiBuf, w](int numLayers, + int layer) { + // Propagate dy from output + // + if(layer == numLayers - 1) + { + const std::vector dy_output_size{ + 1, + static_cast(RBuff.batches_per_layer), + static_cast(out_vec_size)}; + + const std::vector dy_output_stride{ + static_cast(out_vec_size * RBuff.batches_per_layer), + static_cast(out_vec_size), + 1}; + + const 
std::vector dy_workspace_size{ + 1, + static_cast(RBuff.batches_per_layer), + static_cast(RBuff.gemm_write_size())}; + + const std::vector dy_workspace_stride{ + RBuff.layer_stride(), static_cast(RBuff.gemm_write_size()), 1}; + + auto dy_output_desc = + miopen::TensorDescriptor(rnn_data_type, dy_output_size, dy_output_stride); + auto dy_workspace_desc = + miopen::TensorDescriptor(rnn_data_type, dy_workspace_size, dy_workspace_stride); + + int dy_output_offset = 0; + int dy_workspace_offset = RBuff.layer_offset(layer); + + // dY(l,t) = dy(l,t); t = 1:seq_len - 1 + // + CopyTensor(handle, + dy_output_desc, + dy, + dy_workspace_desc, + workSpace, + dy_output_offset, + dy_workspace_offset); + } + // Propagate dy from previous layer + // + else + { + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + false, + RBuff.batches_per_layer, + RBuff.gemm_write_size(), + RBuff.gemm_write_size(), + RBuff.gemm_write_size(), + RBuff.gemm_write_size(), + RBuff.gemm_write_size(), + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + rnn_data_type, + false}; - if(in_n.at(use_time) > 0) - { - if(in_n.at(use_time) != in_n.at(cur_time)) - { - sp_size[1] = in_n.at(use_time); - sp_desc = miopen::TensorDescriptor( - wDesc.GetType(), sp_size, sp_stride); - } + int dy_prev_layer_offset = RBuff.layer_offset(layer + 1); + int dy_current_layer_offset = RBuff.layer_offset(layer); + + // dY(l,t) = dHt(l+1,t)/Why; t = 1:seq_len - 1 + // + miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + workSpace, + dy_prev_layer_offset, + w, + WeiBuf.input_weight_offset(layer + 1), + workSpace, + dy_current_layer_offset, + GemmBackend_t::rocblas); - OpTensor(handle, - miopenTensorOpMul, - &alpha0, - sp_desc, - reserveSpace, - &alpha1, - sp_desc, - reserveSpace, - &beta_t, - sp_desc, - reserveSpace, - offset + static_cast(ri) * wei_len + - static_cast(nLayers) * batch_n * hy_stride, - pretime_shift + hid_off + ri * hy_h, - offset + hid_off + static_cast(ri) * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); } } } + }; - bacc += in_n.at(ti); - } + auto back_propagate_dhy_output = [&RBuff, + &handle, + hidden_size, + &batches, + &bacc_per_time, + rnn_data_type, + dhy, + workSpace, + seqLen, + &get_HxBuff_offset]( + int layer, int time, RnnDirection direction) { + if(dhy == nullptr) + return; - // update hy, cy - if(hy != nullptr || (rnnMode == miopenLSTM && cy != nullptr)) - { - hx_size[2] = hy_h; - sp_size[2] = hy_h; + const int start_time = direction == RnnDirection::Forward ? 0 : seqLen - time - 1; - bacc = batch_n; - baccbi = 0; - for(int ti = seqLen - 1; ti >= 0; ti--) - { - bacc -= in_n.at(ti); - for(int ri = 0; ri < bi; ri++) - { - int cur_time = ri == 0 ? ti : seqLen - 1 - ti; - int cur_batch = ri == 0 ? bacc : baccbi; - int use_batch = 0; + float alpha0 = 1; + float alpha1 = 1; + float beta_t = 0; - if(ti < seqLen - 1) - { - int use_time = ri == 0 ? 
ti + 1 : seqLen - 2 - ti; - use_batch = in_n.at(use_time); - } + std::vector dhy_out_stride{ + batches.at(start_time, RnnDirection::Forward) * hidden_size, hidden_size, 1}; - if(in_n.at(cur_time) > use_batch) - { - offset = hid_shift + cur_batch * hy_stride; + std::vector dhy_workspace_stride{ + static_cast(RBuff.layer_stride()), RBuff.gemm_write_size(), 1}; - sp_size[1] = in_n.at(cur_time) - use_batch; - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + std::vector dhy_out_size{1, batches.at(time, direction), hidden_size}; + std::vector dhy_workspace_size{1, batches.at(time, direction), hidden_size}; - hx_size[1] = sp_size[1]; - hx_desc = miopen::TensorDescriptor(wDesc.GetType(), hx_size, hx_stride); + auto dhy_out_desc = miopen::TensorDescriptor(rnn_data_type, dhy_out_size, dhy_out_stride); - if(hy != nullptr) - { - CopyTensor(handle, - sp_desc, - reserveSpace, - hx_desc, - hy, - static_cast(offset) + hid_off + ri * hy_h + - use_batch * hy_stride, - hx_shift + ri * hy_n * hy_h + use_batch * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } + auto dhy_workspace_desc = + miopen::TensorDescriptor(rnn_data_type, dhy_workspace_size, dhy_workspace_stride); - if(rnnMode == miopenLSTM && cy != nullptr) - { - CopyTensor(handle, - sp_desc, - reserveSpace, - hx_desc, - cy, - static_cast(offset) + bi * wei_len + ri * hy_h + - use_batch * hy_stride, - hx_shift + ri * hy_n * hy_h + use_batch * hy_h); - // Update time - profileRNNkernels(handle, 1, ctime); - } - } - } - baccbi += in_n.at(seqLen - 1 - ti); - } - } - } + auto dhy_out_offset = get_HxBuff_offset(layer, 0, direction); + auto dhy_workspace_offset = + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction); - // output - prelayer_shift = (static_cast(nLayers) - 1) * batch_n * hy_stride + hid_off; + // dHt(l,seq_len - 1) = dY(l,seq_len - 1) + dHy(seq_len - 1) + // + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + dhy_out_desc, + dhy, + &alpha1, + dhy_workspace_desc, + workSpace, + &beta_t, + dhy_workspace_desc, + workSpace, + dhy_out_offset, + dhy_workspace_offset, + dhy_workspace_offset); + }; - sp_size[1] = batch_n; - sp_size[2] = hy_h * bi; - y_size[1] = batch_n; - y_size[2] = out_h; - y_desc = miopen::TensorDescriptor(wDesc.GetType(), y_size, y_stride); - sp_desc = miopen::TensorDescriptor(wDesc.GetType(), sp_size, sp_stride); + auto back_propagate_dhy_prev = [&RBuff, + &handle, + hidden_size, + &batches, + &bacc_per_time, + rnn_data_type, + dhy, + workSpace, + &get_HxBuff_offset, + WeiBuf, + w, + max_batch](int layer, int time, RnnDirection direction) { + auto dbatches = batches.at(time, direction) - batches.next(time, direction); - CopyTensor(handle, sp_desc, reserveSpace, y_desc, y, prelayer_shift, 0); - // Update time - profileRNNkernels(handle, 2, ctime); + if(direction == RnnDirection::Forward && dhy != nullptr && dbatches > 0) + { + std::vector dhy_stride{max_batch * hidden_size, hidden_size, 1}; -#if MIOPEN_BACKEND_HIP - if(is_profiling) - { - float eventTime_mS = RNNProfilingEnd(handle, start, stop); - handle.EnableProfiling(true); - handle.ResetKernelTime(); - handle.AccumKernelTime(eventTime_mS); - } -#endif + std::vector dht_tensor_stride{ + static_cast(RBuff.layer_stride()), RBuff.gemm_write_size(), 1}; -#else - (void)handle; - (void)seqLen; - (void)xDesc; - (void)x; - (void)w; - (void)hx; - (void)cx; - (void)y; - (void)hyDesc; - (void)hy; - (void)yDesc; - (void)cy; - (void)hxDesc; - (void)wDesc; - (void)reserveSpace; - (void)reserveSpaceSize; - 
MIOPEN_THROW("GEMM is not supported"); -#endif -}; + std::vector dhy_size{1, dbatches, hidden_size}; + + std::vector dht_tensor_size{1, dbatches, hidden_size}; + + float alpha0 = 1; + float alpha1 = 1; + float beta_t = 0; + + auto dhy_tensor_desc = miopen::TensorDescriptor(rnn_data_type, dhy_size, dhy_stride); + auto dht_tensor_desc = + miopen::TensorDescriptor(rnn_data_type, dht_tensor_size, dht_tensor_stride); + + auto dhy_offset = get_HxBuff_offset(layer, batches.next(time, direction), direction); + auto dht_tensor_offset = RBuff.gemm_write_offset(layer, + bacc_per_time.at(time, direction) + + batches.next(time, direction), + direction); + // dHt(t,l) = dY(t,l) + dHy(t,l) for relative batches = batches(t+1):batches(t) + // + OpTensor(handle, + miopenTensorOpAdd, + &alpha0, + dhy_tensor_desc, + dhy, + &alpha1, + dht_tensor_desc, + workSpace, + &beta_t, + dht_tensor_desc, + workSpace, + dhy_offset, + dht_tensor_offset, + dht_tensor_offset); + } -void RNNDescriptor::RNNBackwardData(Handle& handle, - const int seqLen, - c_array_view yDesc, - ConstData_t y, - c_array_view dyDesc, - ConstData_t dy, - const TensorDescriptor& dhyDesc, - ConstData_t dhy, - const TensorDescriptor& dcyDesc, - ConstData_t dcy, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& hxDesc, - ConstData_t hx, - const TensorDescriptor& cxDesc, - ConstData_t cx, - c_array_view dxDesc, - Data_t dx, - const TensorDescriptor& dhxDesc, - Data_t dhx, - const TensorDescriptor& dcxDesc, - Data_t dcx, - Data_t workSpace, - size_t workSpaceSize, - Data_t reserveSpace, - size_t reserveSpaceSize) const -{ - // Suppress warning - (void)y; - (void)yDesc; - (void)wDesc; + int dht_batch_size = direction == RnnDirection::Forward ? batches.next(time, direction) + : batches.at(time, direction); - if(dx == nullptr || w == nullptr || dy == nullptr) - { - MIOPEN_THROW(miopenStatusBadParm); - } - if(dhyDesc.GetSize() != dcyDesc.GetSize() || dhyDesc.GetSize() != hxDesc.GetSize() || - dhyDesc.GetSize() != cxDesc.GetSize() || dhyDesc.GetSize() != dhxDesc.GetSize() || - dhyDesc.GetSize() != dcxDesc.GetSize()) - { - MIOPEN_THROW(miopenStatusBadParm); - } + if(dht_batch_size <= 0) + return; -#if MIOPEN_BACKEND_HIP - HipEventPtr start = nullptr; - HipEventPtr stop = nullptr; - bool is_profiling = handle.IsProfilingEnabled(); + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + false, + dht_batch_size, + hidden_size, + hidden_size, + RBuff.gemm_write_size(), + hidden_size, + RBuff.gemm_write_size(), + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + rnn_data_type, + false}; - if(is_profiling) - { - handle.EnableProfiling(false); - RNNProfilingBegin(handle, start, stop); - } - try - { -#endif + // dHt(l,t) = dY(t,l) + dHt^(t+1)/Whh for relative batches = 0:batches(t+1) + // + int dht_next_deactivated_offset = + RBuff.gemm_write_offset(layer, bacc_per_time.next(time, direction), direction); + int dht_offset = + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction); + miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + workSpace, + dht_next_deactivated_offset, + w, + WeiBuf.hidden_weight_offset(layer, direction), + workSpace, + dht_offset, + GemmBackend_t::rocblas); - if(paddingMode == miopenRNNIONotPadded) + if(gemm_status != miopenStatusSuccess) { - RNNBackwardDataPackedTensors(handle, - seqLen, - dyDesc, - dy, - dhy, - dcy, - w, - hx, - cx, - dxDesc, - dx, - dhxDesc, - dhx, - dcxDesc, - dcx, - workSpace, - workSpaceSize, - 
reserveSpace, - reserveSpaceSize); + if(gemm_status == miopenStatusNotImplemented) + { + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); + } + } + }; + + auto back_propagate_dhy_time = [&RBuff, + &handle, + hidden_size, + seqLen, + &batches, + &bacc_per_time, + rnn_data_type, + workSpace, + reserveSpace, + &activDesc, + &back_propagate_dhy_output, + &back_propagate_dhy_prev]( + int layer, int time, RnnDirection direction) { + if(time == seqLen - 1) + { + back_propagate_dhy_output(layer, time, direction); } else { - Data_t packedDYIn = workSpace; - size_t packedDXSize, packedDYSize; - std::tie(packedDXSize, packedDYSize) = - RNNTensorPaddingConverter::GetTempPackedBuffersSpace(*this, dxDesc); + back_propagate_dhy_prev(layer, time, direction); + } - Data_t packedDXOut = - static_cast(reinterpret_cast(workSpace) + packedDYSize); + std::vector dht_tensor_stride{ + RBuff.batches_per_layer * RBuff.gemm_write_size(), RBuff.gemm_write_size(), 1}; + std::vector dht_tensor_size{1, batches.at(time, direction), hidden_size}; + + auto dht_desc = miopen::TensorDescriptor(rnn_data_type, dht_tensor_size, dht_tensor_stride); + + float alpha = 1, beta = 0; + // dHt^(l,t) = @^-1(dHt(l,t)) + // + activDesc.Backward( + handle, + &alpha, + dht_desc, + reserveSpace, + dht_desc, + workSpace, + dht_desc, + reserveSpace, + &beta, + dht_desc, + workSpace, + RBuff.hidden_offset(layer, bacc_per_time.at(time, direction), direction), + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction), + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction), + RBuff.gemm_write_offset(layer, bacc_per_time.at(time, direction), direction)); + }; - auto shifted_workSpace = static_cast(reinterpret_cast(workSpace) + - (packedDYSize + packedDXSize)); - auto shifted_workSpace_size = workSpaceSize - (packedDYSize + packedDXSize); + auto back_propagate_dhy = [this, seqLen, &back_propagate_dhy_time](int layer) { + for(int time = seqLen - 1; time >= 0; time--) + { + back_propagate_dhy_time(layer, time, RnnDirection::Forward); + if(dirMode == 0u) + continue; + back_propagate_dhy_time(layer, time, RnnDirection::Backward); + } + }; - std::vector in_n(seqLen); + auto forward_propagate_dhx_prev = [&RBuff, + &WeiBuf, + rnn_data_type, + hidden_size, + &batches, + &bacc_per_time, + &handle, + w, + dhx, + &get_HxBuff_offset, + workSpace](int layer, int time, RnnDirection direction) { + int dbatches = time == 0 ? batches.at(time, direction) + : batches.at(time, direction) - batches.prev(time, direction); + + if(dbatches <= 0) + return; - for(int i = 0; i < seqLen; i++) + miopen::GemmDescriptor gemm_desc = GemmDescriptor{false, + false, + false, + dbatches, + hidden_size, + hidden_size, + RBuff.gemm_write_size(), + hidden_size, + hidden_size, + 1, // batch count + 0, // Stride A + 0, // Stride B + 0, // Stride C + 1, // alpha + 1, // beta + rnn_data_type, + false}; + + int dhx_batch = time == 0 ? 0 : batches.prev(time, direction); + + int dht_batch = time == 0 ? 
bacc_per_time.at(time, direction) + : bacc_per_time.prev(time, direction) - dbatches; + + int dhx_offset = get_HxBuff_offset(layer, dhx_batch, direction); + int dht_prev_offset = RBuff.gemm_write_offset(layer, dht_batch, direction); + + // dhx(l,t) = dHt(l,t-1)/Whh for relative batches = batches(t+1):batches(t) + // + miopenStatus_t gemm_status = CallGemm(handle, + gemm_desc, + workSpace, + dht_prev_offset, + w, + WeiBuf.hidden_weight_offset(layer, direction), + dhx, + dhx_offset, + GemmBackend_t::rocblas); + + if(gemm_status != miopenStatusSuccess) + { + if(gemm_status == miopenStatusNotImplemented) { - int batchval, batchvalout; - std::tie(batchval, std::ignore) = miopen::tien<2>(dxDesc[i].GetLengths()); - std::tie(batchvalout, std::ignore) = miopen::tien<2>(dyDesc[i].GetLengths()); - if(batchval != batchvalout) - { - MIOPEN_THROW(miopenStatusBadParm, - "Input batch length: " + std::to_string(batchval) + - ", Output batch length: " + std::to_string(batchvalout)); - } - in_n[i] = batchval; + MIOPEN_LOG_E("GEMM not implemented"); + } + else + { + MIOPEN_LOG_E("GEMM failed"); } + } + }; - RNNTensorPaddingConverter::ConvertTensorData( - handle, dyDesc[0], in_n, dy, packedDYIn, true); + auto forward_propagate_dhx = [this, seqLen, &forward_propagate_dhx_prev, dhx](int layer) { + if(dhx == nullptr) + return; - RNNDescriptor packedRnnDesc(*this); - packedRnnDesc.SetPaddingmode(miopenRNNIONotPadded); + for(int time = 0; time < seqLen; time++) + { + forward_propagate_dhx_prev(layer, time, RnnDirection::Forward); - packedRnnDesc.RNNBackwardDataPackedTensors(handle, - seqLen, - dyDesc, - packedDYIn, - dhy, - dcy, - w, - hx, - cx, - dxDesc, - packedDXOut, - dhxDesc, - dhx, - dcxDesc, - dcx, - shifted_workSpace, - shifted_workSpace_size, - reserveSpace, - reserveSpaceSize); + if(dirMode == 0u) + continue; - RNNTensorPaddingConverter::ConvertTensorData( - handle, dxDesc[0], in_n, packedDXOut, dx, false); + forward_propagate_dhx_prev(layer, time, RnnDirection::Backward); } + }; -#if MIOPEN_BACKEND_HIP - } - catch(...) 
+    for(int li = static_cast<int>(nLayers) - 1; li >= 0; li--)
     {
-        if(is_profiling)
-            handle.EnableProfiling(true);
-        throw;
+        back_propagate_dy(nLayers, li);
+        back_propagate_dhy(li);
+        forward_propagate_dhx(li);
     }

-    if(is_profiling)
+    int hy_stride = hidden_size * bi * static_cast<int>(workspaceScale);
+
+    if(inputMode == miopenRNNskip)
     {
-        float eventTime_mS = RNNProfilingEnd(handle, start, stop);
-        handle.EnableProfiling(true);
-        handle.ResetKernelTime();
-        handle.AccumKernelTime(eventTime_mS);
+        auto workspace_desc =
+            miopen::TensorDescriptor(rnn_data_type,
+                                     {1, total_batch_size, hidden_size},
+                                     {static_cast<size_t>(total_batch_size) * out_stride, out_stride, 1});
+        auto dx_desc = miopen::TensorDescriptor(rnn_data_type,
+                                                {1, total_batch_size, hidden_size},
+                                                {static_cast<size_t>(total_batch_size) * in_stride, in_stride, 1});
+
+        float alpha0 = 1;
+        float alpha1 = 1;
+        float beta_t = 0;
+
+        for(int gi = 0; gi < nHiddenTensorsPerLayer * bi; gi++)
+        {
+            OpTensor(handle,
+                     miopenTensorOpAdd,
+                     &alpha0,
+                     workspace_desc,
+                     workSpace,
+                     &alpha1,
+                     dx_desc,
+                     dx,
+                     &beta_t,
+                     dx_desc,
+                     dx,
+                     static_cast<size_t>(gi) * hidden_size,
+                     0,
+                     0);
+        }
+    }
+    else
+    {
+        miopen::GemmDescriptor gemm_desc = GemmDescriptor{false,
+                                                          false,
+                                                          false,
+                                                          total_batch_size,
+                                                          input_size,
+                                                          RBuff.gemm_write_size(),
+                                                          hy_stride,
+                                                          in_stride,
+                                                          in_stride,
+                                                          1, // batch count
+                                                          0, // Stride A
+                                                          0, // Stride B
+                                                          0, // Stride C
+                                                          1, // alpha
+                                                          0, // beta
+                                                          rnn_data_type,
+                                                          false};
+        miopenStatus_t gemm_status =
+            CallGemm(handle, gemm_desc, workSpace, 0, w, 0, dx, 0, GemmBackend_t::rocblas);
+
+        if(gemm_status != miopenStatusSuccess)
+        {
+            if(gemm_status == miopenStatusNotImplemented)
+            {
+                MIOPEN_LOG_E("GEMM not implemented");
+            }
+            else
+            {
+                MIOPEN_LOG_E("GEMM failed");
+            }
+        }
+    }
+
+#else
+    (void)handle;
+    (void)seqLen;
+    (void)dhy;
+    (void)dyDesc;
+    (void)dy;
+    (void)w;
+    (void)dxDesc;
+    (void)dx;
+    (void)dhxDesc;
+    (void)dhx;
+    (void)workSpace;
+    (void)workSpaceSize;
+    (void)reserveSpace;
+    (void)reserveSpaceSize;
+    MIOPEN_THROW("GEMM is not supported");
 #endif
-}
+};

 void RNNDescriptor::RNNBackwardDataPackedTensors(
     Handle& handle,
@@ -4053,6 +5357,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors(
     x_stride[1] = in_stride;
     y_stride[0] = batch_n * out_stride;
     y_stride[1] = out_stride;
+
     if(dhx != nullptr || (rnnMode == miopenLSTM && dcx != nullptr))
     {
         hx_size[2] = hy_d * hy_n * hy_h;
@@ -4072,6 +5377,7 @@ void RNNDescriptor::RNNBackwardDataPackedTensors(
             profileRNNkernels(handle, 1, ctime);
         }
     }
+
     hx_stride[0] = in_n.at(0) * uni_stride;
     hx_stride[1] = uni_stride;
@@ -6388,5 +7694,4 @@ void RNNDescriptor::RNNBackwardWeightsPackedTensors(
     MIOPEN_THROW("GEMM is not supported");
 #endif
 };
-
 } // namespace miopen