diff --git a/src/onnx/parse_lstm.cpp b/src/onnx/parse_lstm.cpp index 352d0977c5b..871f2593ecb 100644 --- a/src/onnx/parse_lstm.cpp +++ b/src/onnx/parse_lstm.cpp @@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector& actv } } +void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector& args) +{ + std::vector perm{1, 0, 2}; + args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]); + + if(args.size() >= 6 and not args[5]->is_undefined()) + { + args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]); + } + + if(args.size() >= 7 and not args[6]->is_undefined()) + { + args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]); + } +} + +void lstm_transpose_outputs(onnx_parser::node_info& info, + instruction_ref& hidden_states, + instruction_ref& last_output, + instruction_ref& last_cell_output) +{ + std::vector perm_hs{2, 0, 1, 3}; + hidden_states = + info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states); + std::vector perm_last{1, 0, 2}; + last_output = + info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output); + last_cell_output = + info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output); +} + struct parse_lstm : op_parser { std::vector operators() const { return {{"LSTM"}}; } @@ -202,6 +233,12 @@ struct parse_lstm : op_parser input_forget = parser.parse_value(info.attributes.at("input_forget")).at(); } + int layout = 0; + if(contains(info.attributes, "layout")) + { + layout = parser.parse_value(info.attributes.at("layout")).at(); + } + // append undefined opeator to make 6 arguments if(args.size() < 8) { @@ -209,6 +246,11 @@ struct parse_lstm : op_parser args.insert(args.end(), 8 - args.size(), ins); } + if(layout != 0) + { + lstm_transpose_inputs(info, args); + } + // first output for concatenation of hidden states auto hidden_states = info.add_instruction(make_op("lstm", {{"hidden_size", hidden_size}, @@ -224,6 +266,11 @@ struct parse_lstm : op_parser auto last_cell_output = info.add_instruction(make_op("rnn_last_cell_output"), hidden_states); + if(layout != 0) + { + lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output); + } + return {hidden_states, last_output, last_cell_output}; } }; diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index dd1f90d755f..24b9a8267bd 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -4484,6 +4484,177 @@ def lrn_test(): return ([node], [x], [y]) +@onnx_test() +def lstm_bi_layout_cell_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60]) + + cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT, + [3, 2, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['', '', 'cellout'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='bidirectional', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout]) + + +@onnx_test() +def lstm_bi_layout_last_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [2, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [2, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 2, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 2, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [2, 60]) + + hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 2, 20]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [3, 2, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['hs', 'output'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='bidirectional', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output]) + + +@onnx_test() +def lstm_f_layout_hs_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60]) + + hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [3, 1, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['hs', 'output'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='forward', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs, output]) + + +@onnx_test() +def lstm_f_layout_cell_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60]) + + cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT, + [3, 1, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['', '', 'cellout'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='forward', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [cellout]) + + +@onnx_test() +def lstm_r_layout_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60]) + + hs = helper.make_tensor_value_info('hs', TensorProto.FLOAT, [3, 5, 1, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['hs'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='reverse', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [hs]) + + +@onnx_test() +def lstm_r_layout_hs_cell_test(): + seq = helper.make_tensor_value_info('seq', TensorProto.FLOAT, [3, 5, 10]) + w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [1, 80, 10]) + r = helper.make_tensor_value_info('r', TensorProto.FLOAT, [1, 80, 20]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 160]) + seq_len = helper.make_tensor_value_info('seq_len', TensorProto.INT32, [3]) + h0 = helper.make_tensor_value_info('h0', TensorProto.FLOAT, [3, 1, 20]) + c0 = helper.make_tensor_value_info('c0', TensorProto.FLOAT, [3, 1, 20]) + pph = helper.make_tensor_value_info('pph', TensorProto.FLOAT, [1, 60]) + + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, + [3, 1, 20]) + cellout = helper.make_tensor_value_info('cellout', TensorProto.FLOAT, + [3, 1, 20]) + + node = onnx.helper.make_node( + 'LSTM', + inputs=['seq', 'w', 'r', 'bias', 'seq_len', 'h0', 'c0', 'pph'], + outputs=['', 'output', 'cellout'], + activations=['sigmoid', 'tanh', 'tanh'], + clip=0, + direction='reverse', + hidden_size=20, + input_forget=1, + layout=1) + + return ([node], [seq, w, r, bias, seq_len, h0, c0, pph], [output, cellout]) + + @onnx_test() def matmul_bmbm_test(): m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 6, 7]) diff --git a/test/onnx/lstm_bi_layout_cell_test.onnx b/test/onnx/lstm_bi_layout_cell_test.onnx new file mode 100644 index 00000000000..cce2083c286 Binary files /dev/null and b/test/onnx/lstm_bi_layout_cell_test.onnx differ diff --git a/test/onnx/lstm_bi_layout_last_test.onnx b/test/onnx/lstm_bi_layout_last_test.onnx new file mode 100644 index 00000000000..b18bbba76fb Binary files /dev/null and b/test/onnx/lstm_bi_layout_last_test.onnx differ diff --git a/test/onnx/lstm_f_layout_cell_test.onnx b/test/onnx/lstm_f_layout_cell_test.onnx new file mode 100644 index 00000000000..134447e8c9d Binary files /dev/null and b/test/onnx/lstm_f_layout_cell_test.onnx differ diff --git a/test/onnx/lstm_f_layout_hs_test.onnx b/test/onnx/lstm_f_layout_hs_test.onnx new file mode 100644 index 00000000000..1626925a31b Binary files /dev/null and b/test/onnx/lstm_f_layout_hs_test.onnx differ diff --git a/test/onnx/lstm_r_layout_hs_cell_test.onnx b/test/onnx/lstm_r_layout_hs_cell_test.onnx new file mode 100644 index 00000000000..2cf85d4d120 Binary files /dev/null and b/test/onnx/lstm_r_layout_hs_cell_test.onnx differ diff --git a/test/onnx/lstm_r_layout_test.onnx b/test/onnx/lstm_r_layout_test.onnx new file mode 100644 index 00000000000..53442b11641 Binary files /dev/null and b/test/onnx/lstm_r_layout_test.onnx differ diff --git a/test/onnx/onnx_rnn_test.cpp b/test/onnx/onnx_rnn_test.cpp index 5ba978ae617..55041285175 100644 --- a/test/onnx/onnx_rnn_test.cpp +++ b/test/onnx/onnx_rnn_test.cpp @@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward) } } +TEST_CASE(lstm_forward_layout) +{ + std::size_t sl = 5; // sequence len + std::size_t bs = 3; // batch size + std::size_t hs = 20; // hidden size + std::size_t is = 10; // input size + std::size_t nd = 1; // num directions + float clip = 0.0f; + int input_forget = 1; + migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}}; + migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; + migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; + migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; + + // 8 args, hs and last output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + std::vector perm_hid{2, 0, 1, 3}; + out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), + out_hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + + auto prog = optimize_onnx("lstm_f_layout_hs_test.onnx"); + + EXPECT(p == prog); + } + // 8 args, cell output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); + auto prog = optimize_onnx("lstm_f_layout_cell_test.onnx"); + + EXPECT(p == prog); + } +} + // activation functions TEST_CASE(lstm_forward_actv_func) { @@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse) } } +TEST_CASE(lstm_reverse_layout) +{ + std::size_t sl = 5; // sequence len + std::size_t bs = 3; // batch size + std::size_t hs = 20; // hidden size + std::size_t is = 10; // input size + std::size_t nd = 1; // num directions + float clip = 0.0f; + int input_forget = 1; + migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}}; + migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; + migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; + migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; + + // 8 args, hs output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + std::vector perm_hid{2, 0, 1, 3}; + out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), + out_hs); + auto prog = optimize_onnx("lstm_r_layout_test.onnx"); + + EXPECT(p == prog); + } + + // 8 args, last and cell output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + last_output = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), + last_output); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); + + auto prog = optimize_onnx("lstm_r_layout_hs_cell_test.onnx"); + + EXPECT(p == prog); + } +} + TEST_CASE(lstm_bidirectional) { std::size_t sl = 5; // sequence len @@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional) } } +TEST_CASE(lstm_bidirectional_layout) +{ + std::size_t sl = 5; // sequence len + std::size_t bs = 3; // batch size + std::size_t hs = 20; // hidden size + std::size_t is = 10; // input size + std::size_t nd = 2; // num directions + float clip = 0.0f; + int input_forget = 1; + migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}}; + migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}}; + migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}}; + migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}}; + // 0 activation function + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + std::vector perm_hid{2, 0, 1, 3}; + out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), + out_hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx"); + + EXPECT(p == prog); + } + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_parameter("seq", seq_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", bias_shape); + auto seq_len = mm->add_parameter("seq_len", sl_shape); + auto ih = mm->add_parameter("h0", ih_shape); + auto ic = mm->add_parameter("c0", ih_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hs}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh"), + migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", input_forget}}), + seq, + w, + r, + bias, + seq_len, + ih, + ic, + pph); + auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); + auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx"); + + EXPECT(p == prog); + } +} + TEST_CASE(lstm_bi_actv_funcs) { std::size_t sl = 5; // sequence len diff --git a/test/py/onnx_backend_test.py b/test/py/onnx_backend_test.py index cc69799bff8..f51e4ce45f4 100644 --- a/test/py/onnx_backend_test.py +++ b/test/py/onnx_backend_test.py @@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test): # fails # from OnnxBackendNodeModelTest backend_test.exclude(r'test_gru_batchwise_cpu') - backend_test.exclude(r'test_lstm_batchwise_cpu') backend_test.exclude(r'test_simple_rnn_batchwise_cpu') # from OnnxBackendPyTorchConvertedModelTest backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu') diff --git a/test/ref/rnn_ops.cpp b/test/ref/rnn_ops.cpp index 7677508a240..f6341fff67d 100644 --- a/test/ref/rnn_ops.cpp +++ b/test/ref/rnn_ops.cpp @@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward) } } +TEST_CASE(lstm_forward_layout) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260}; + + std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, + -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, + 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, + 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, + -0.3025, 0.3637, -0.3181, -0.4655}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366, + 0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332, + 1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331, + -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, + -1.9004, + 0.3337, + 0.5741, + 0.5671, + 0.0458, + 0.4514, + -0.8968, + -0.9201, + 0.1962, + 0.5771, + -0.5332}; + + std::vector ic_data{0.9569, + -0.5981, + 1.1312, + 1.0945, + 1.1055, + -0.1212, + -0.9097, + 0.7831, + -1.6991, + -1.9498, + -1.2567, + -0.4114}; + + std::vector pph_data{1.84369764, + 0.68413646, + -0.44892886, + -1.50904413, + 0.3860796, + -0.52186625, + 1.08474445, + -1.80867321, + 1.32594529, + 0.4336262, + -0.83699064, + 0.49162736}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // forward, hidden state concatenation as output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + p.compile(migraphx::make_target("ref")); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.0417273, -0.272355, 0.206765, 0.223879, 0.0742487, -0.0800085, 0.259897, + 0.0670196, -0.00532985, 0.0440265, 0.29654, -0.0463156, -0.0847427, 0.0874114, + 0.304256, -0.0585745, 0.138193, -0.0322939, -0.0891815, 0.15773, 0.184266, + 0.0610048, -0.138041, 0.0963885, 0.0498799, 0.125772, 0.0533032, -0.131413, + -0.0223018, 0.131113, 0.135643, -0.056620, 0.19139, -0.127708, -0.409371, + -0.136186, 0.0213755, -0.146027, -0.0324509, -0.0620429, 0.0988431, -0.018085, + -0.159434, 0.030266, 0.142701, 0.0342236, -0.198664, 0.0702607}; + EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); + } + + // forward, last_output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + p.compile(migraphx::make_target("ref")); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.0847427, + 0.0874114, + 0.304256, + -0.0585745, + -0.0223018, + 0.131113, + 0.135643, + -0.0566208, + 0.142701, + 0.0342236, + -0.198664, + 0.0702607}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // forward, last_cell_output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + und); + auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output); + p.compile(migraphx::make_target("ref")); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.111454, + 0.247794, + 0.471087, + -0.220574, + -0.048196, + 0.263184, + 0.283258, + -0.14882, + 0.605585, + 0.078598, + -0.64457, + 0.119811}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } +} + TEST_CASE(lstm_forward_more) { std::size_t batch_size = 3; @@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more) } } -TEST_CASE(lstm_reverse) +TEST_CASE(lstm_forward_more_layout) { std::size_t batch_size = 3; std::size_t seq_len = 4; @@ -3527,32 +3785,668 @@ TEST_CASE(lstm_reverse) std::size_t input_size = 3; std::size_t num_dirct = 1; std::vector w_data{ - -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, - -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, - -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, - 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, - -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285}; std::vector r_data{ - -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, - 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, - -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, - 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, - 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, - -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, - 0.3987, -0.1687, -0.0032, -0.1038}; - - std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, - -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, - 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, - -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, - -0.4386, 0.4208, 0.0717, 0.3789}; + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260}; + + std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, + -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, + 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, + 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, + -0.3025, 0.3637, -0.3181, -0.4655}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366, + 0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332, + 1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331, + -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, + -1.9004, + 0.3337, + 0.5741, + 0.5671, + 0.0458, + 0.4514, + -0.8968, + -0.9201, + 0.1962, + 0.5771, + -0.5332}; + + std::vector ic_data{0.9569, + -0.5981, + 1.1312, + 1.0945, + 1.1055, + -0.1212, + -0.9097, + 0.7831, + -1.6991, + -1.9498, + -1.2567, + -0.4114}; + + std::vector pph_data{1.84369764, + 0.68413646, + -0.44892886, + -1.50904413, + 0.3860796, + -0.52186625, + 1.08474445, + -1.80867321, + 1.32594529, + 0.4336262, + -0.83699064, + 0.49162736}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // forward, 3 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + p.compile(migraphx::make_target("ref")); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.0786602, -0.0613048, 0.179592, + -0.071286, -0.102509, -0.0372696, 0.252296, -0.144544, -0.165194, -0.0372928, + 0.273786, -0.100877, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.074206, + 0.0124086, -0.139544, 0.108016, 0.00496085, 0.0662588, -0.048577, -0.187329, + -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.101585, 0.0687269, -0.161725, + -0.25617, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.0855831, -0.0171894, + -0.140202, 0.0828391, 0.136898, 0.00160891, -0.184812, 0.147774}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // forward, 8 args + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + p.compile(migraphx::make_target("ref")); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, 0.186991, -0.0624168, 0.205513, + 0.0836373, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.058052, 0.0795391, + 0.266617, -0.0128746, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.421857, + 0.0459771, -0.144955, 0.0720673, 0.218258, 0.0944405, 0.0431211, -0.132394, + 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.022618, -0.121195, -0.4065, + -0.252054, -0.0300906, -0.0890598, -0.135266, -0.0413375, 0.103489, 0.0142918, + -0.123408, 0.0401075, 0.187761, 0.0501726, -0.121584, 0.0606723}; + + EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); + } + + // forward, last_output as program output, sequence length shorter + // than max_seq_len + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + und); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + p.compile(migraphx::make_target("ref")); + + auto last_hs = p.eval({}).back(); + std::vector output_data; + last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{-0.0847427, + 0.0874114, + 0.304256, + -0.0585745, + -0.0223018, + 0.131113, + 0.135643, + -0.0566208, + 0.142701, + 0.0342236, + -0.198664, + 0.0702607}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // seq_len = 1 + { + seq_len = 1; + migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + p.compile(migraphx::make_target("ref")); + + auto hs_concat = p.eval({}).back(); + std::vector hs_data; + hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{0.079753, + -0.289854, + 0.160043, + 0.115056, + 0.294074, + -0.0319677, + -0.0955337, + 0.104168, + 0.022618, + -0.121195, + -0.4065, + -0.252054}; + EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); + } +} + +TEST_CASE(lstm_reverse) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, + -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, + -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, + 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, + -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, + 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, + -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, + 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, + 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, + -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, + 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, + -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, + 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, + -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.5289, + 1.0986, + 0.6091, + 1.6462, + 0.8720, + 0.5349, + -0.1962, + -1.7416, + -0.9912, + 1.2831, + 1.0896, + -0.6959}; + + std::vector ic_data{-0.8323, + 0.3998, + 0.1831, + 0.5938, + 2.7096, + -0.1790, + 0.0022, + -0.8040, + 0.1578, + 0.0567, + 0.8069, + -0.5141}; + + std::vector pph_data{-0.8271, + -0.5683, + 0.4562, + -1.2545, + 1.2729, + -0.4082, + -0.4392, + -0.9406, + 0.7794, + 1.8194, + -0.5811, + 0.2166}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + float clip = 0.0f; + // reverse, concatenation of hidden states as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, + 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, + 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, + 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, + 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, + 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // reverse, sequence lengths are the same, but less than max_seq_lens + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + + migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, + 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, + 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, + 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, + 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, + 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // variable sequence lengths + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data{3, 2, 1}; + auto sql = mm->add_literal(seq_len_s, len_data); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449, + 0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761, + 0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0, + 0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // reverse, 3 args, last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.443077, + -0.325425, + -0.249367, + -0.270812, + 0.122913, + 0.118537, + 0.0370199, + -0.0164687, + -0.00754759, + 0.141613, + 0.348002, + 0.667298}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // reverse, 3 args, 0 actv function + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", {}}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.443077, + -0.325425, + -0.249367, + -0.270812, + 0.122913, + 0.118537, + 0.0370199, + -0.0164687, + -0.00754759, + 0.141613, + 0.348002, + 0.667298}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } +} + +TEST_CASE(lstm_reverse_layout) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 1; + std::vector w_data{ + -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, + -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, + -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, + 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, + -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, + 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, + -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, + 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, + 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, + -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, + 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, + -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, + 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, + -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366, + 0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332, + 1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331, + -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310}; std::vector ih_data{1.5289, 1.0986, @@ -3593,14 +4487,15 @@ TEST_CASE(lstm_reverse) -0.5811, 0.2166}; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; float clip = 0.0f; + // reverse, concatenation of hidden states as program output { migraphx::program p; @@ -3614,7 +4509,13 @@ TEST_CASE(lstm_reverse) auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); auto und = mm->add_instruction(migraphx::make_op("undefined")); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -3633,18 +4534,21 @@ TEST_CASE(lstm_reverse) ih, ic, pph); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ - -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, - 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, - 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, - 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, - 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, - 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, - 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; + -0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681, + -0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845, + 0.177273, -0.0774616, 0.789732, 0.128538, 0.20909, 0.0553812, 0.928866, + 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, 0.130438, + 0.946669, 0.0868676, 0.044508, -0.373961, -0.224905, 0.32421, 0.344048, + 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895, -0.0252804, 0.267356, + 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211, -0.161537}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -3661,14 +4565,20 @@ TEST_CASE(lstm_reverse) auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); - migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}}; + migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}}; std::vector pad_data(pad_seq_s.elements(), 0.0f); auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); - auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p); migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; std::vector len_data(batch_size, static_cast(seq_len)); auto sql = mm->add_literal(seq_len_s, len_data); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -3687,22 +4597,26 @@ TEST_CASE(lstm_reverse) ih, ic, pph); - p.compile(migraphx::make_target("ref")); - auto hs_concat = p.eval({}).back(); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{ - -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, - 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, - 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, - 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, - 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, - 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, - 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0}; + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{ + -0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681, + -0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845, + 0.177273, -0.0774616, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.789732, 0.128538, 0.20909, 0.0553812, + 0.928866, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, + 0.130438, 0.946669, 0.0868676, 0.044508, -0.373961, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.224905, + 0.32421, 0.344048, 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895, + -0.0252804, 0.267356, 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211, + -0.161537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -3722,7 +4636,13 @@ TEST_CASE(lstm_reverse) migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; std::vector len_data{3, 2, 1}; auto sql = mm->add_literal(seq_len_s, len_data); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -3741,18 +4661,22 @@ TEST_CASE(lstm_reverse) ih, ic, pph); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ - -0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307, 0.11468, 0.114449, - 0.0196755, -0.102969, 0.295872, 0.515859, 0.246501, -0.168327, 0.00023761, - 0.167567, -0.0621982, 0.96657, 0.0755112, 0.0620917, -0.264845, 0, - 0, 0, 0, -0.204545, 0.0146403, 0.210057, 0.0296268, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0}; + -0.126517, 0.0359124, 0.107453, -0.0617278, -0.168327, 0.00023761, 0.167567, + -0.0621982, -0.204545, 0.0146403, 0.210057, 0.0296268, 0, 0, + 0, 0, 0.911307, 0.11468, 0.114449, 0.0196755, 0.96657, + 0.0755112, 0.0620917, -0.264845, 0, 0, 0, 0, + 0, 0, 0, 0, -0.102969, 0.295872, 0.515859, + 0.246501, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -3763,6 +4687,10 @@ TEST_CASE(lstm_reverse) auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + auto hs = mm->add_instruction( migraphx::make_op( "lstm", @@ -3777,46 +4705,8 @@ TEST_CASE(lstm_reverse) seq, w, r); - mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); - - p.compile(migraphx::make_target("ref")); - auto hs_concat = p.eval({}).back(); - std::vector output_data; - hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.443077, - -0.325425, - -0.249367, - -0.270812, - 0.122913, - 0.118537, - 0.0370199, - -0.0164687, - -0.00754759, - 0.141613, - 0.348002, - 0.667298}; - EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); - } - - // reverse, 3 args, 0 actv function - { - migraphx::program p; - auto* mm = p.get_main_module(); - auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); - auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); - auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); - auto hs = mm->add_instruction( - migraphx::make_op( - "lstm", - {{"hidden_size", hidden_size}, - {"actv_func", {}}, - {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, - {"clip", clip}, - {"input_forget", 0}}), - seq, - w, - r); - mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); @@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv) 0.8069, -0.5141}; - std::vector pph_data{-0.8271, - -0.5683, - 0.4562, - -1.2545, - 1.2729, - -0.4082, - -0.4392, - -0.9406, - 0.7794, - 1.8194, - -0.5811, - 0.2166}; + std::vector pph_data{-0.8271, + -0.5683, + 0.4562, + -1.2545, + 1.2729, + -0.4082, + -0.4392, + -0.9406, + 0.7794, + 1.8194, + -0.5811, + 0.2166}; + + migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + float clip = 0.0f; + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value( + std::vector{migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, + 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231, + 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, + 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, + 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, + 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // reverse, 3 args, 2 actv functions + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{ + migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.132123, + -0.37531, + -0.12943, + -0.00798307, + -0.133882, + -0.0251383, + 0.0486486, + -0.0220606, + 0.292495, + 0.233866, + 0.48646, + 0.481844}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output + { + seq_len = 1; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; + migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); + + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{-0.104351, + -0.0471426, + -0.0905753, + 0.01506, + 0.059797, + 0.104239, + -0.0266768, + 0.0727547, + -0.146298, + 0.070535, + 0.327809, + 0.407388}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } +} + +TEST_CASE(lstm_bidirectional) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, + -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, + 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, + -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, + 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, + 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, + 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, + 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, + 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, + 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, + -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, + 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{ + 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, + -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, + 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, + -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, + 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, + -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, + -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, + 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, + 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, + 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, + 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, + -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; + + std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, + -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, + -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, + 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; + std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, + -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, + -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, + -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, + 0.7794, 1.8194, -0.5811, 0.2166}; + float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; @@ -3920,95 +5005,200 @@ TEST_CASE(lstm_reverse_actv) migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; - float clip = 0.0f; + + // concatenation of hidden states as program output { migraphx::program p; - auto* mm = p.get_main_module(); - auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); - - auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); - auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, {"actv_func", - migraphx::to_value( - std::vector{migraphx::make_op("sigmoid")})}, - {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"clip", clip}, {"input_forget", 0}}), seq, w, - r); + r, + bias, + und, + ih, + ic, + pph); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ - 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, - 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231, - 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, - 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, - 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, - 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; + 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, + 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, + 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, + 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, + -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, + 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, + 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, + -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, + 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, + 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, + -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, + -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } - // reverse, 3 args, 2 actv functions + // last hidden state as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, + 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, + 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + p.compile(migraphx::make_target("ref")); + auto hs_concat = p.eval({}).back(); + std::vector output_data; + hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ + -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, + 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, + 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + } + + // 3 args, concatenation of hidden states as program output { migraphx::program p; auto* mm = p.get_main_module(); auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); - - auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); - auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); - auto hs = mm->add_instruction( + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, {"actv_func", - migraphx::to_value(std::vector{ - migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})}, - {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"clip", clip}, {"input_forget", 0}}), seq, w, r); - mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.132123, - -0.37531, - -0.12943, - -0.00798307, - -0.133882, - -0.0251383, - 0.0486486, - -0.0220606, - 0.292495, - 0.233866, - 0.48646, - 0.481844}; + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, + 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, + -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, + 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, + 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, + -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, + 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, + -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, + 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, + -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, + 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, + -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, + -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, + -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } - // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output + // sequence length is 1, contenation of hidden state as program output { - seq_len = 1; - std::vector input_data1{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; - migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::program p; auto* mm = p.get_main_module(); + seq_len = 1; + migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + std::vector input_data1{ + -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); - - auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); - auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); mm->add_instruction( migraphx::make_op( "lstm", @@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv) migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), migraphx::make_op("tanh"), migraphx::make_op("tanh")})}, - {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, {"clip", clip}, {"input_forget", 0}}), seq, @@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv) auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); - std::vector output_data_gold{-0.104351, - -0.0471426, - -0.0905753, - 0.01506, - 0.059797, - 0.104239, - -0.0266768, - 0.0727547, - -0.146298, - 0.070535, - 0.327809, - 0.407388}; + std::vector output_data_gold{ + -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, + -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, + -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, + -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } } -TEST_CASE(lstm_bidirectional) +TEST_CASE(lstm_bidirectional_layout) { std::size_t batch_size = 3; std::size_t seq_len = 4; @@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional) -0.4386, 0.4208, 0.0717, 0.3789}; std::vector input_data{ - -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, - -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, - 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, - 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; - - std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, - 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, - 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, - -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; - - std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, - -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, - -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, - 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; + -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366, + 0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332, + 1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331, + -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986, + 0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968, + 0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962, + 0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959}; + + std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998, + 0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831, + 2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498, + -1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141}; std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, @@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional) -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, 0.7794, 1.8194, -0.5811, 0.2166}; float clip = 0.0f; - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; - migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; // concatenation of hidden states as program output @@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional) auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); auto und = mm->add_instruction(migraphx::make_op("undefined")); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional) ih, ic, pph); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ - 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, - 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, - 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, - 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, - 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, - -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, - 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, - 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, - 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, - -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, - 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, - 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, - -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, - -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; + 0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138, + -0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.0459032, 0.0414126, 0.272303, 0.0393149, -0.182201, + -0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746, + -0.185038, -0.026845, 0.177273, -0.0774616, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812, 0.421857, 0.0459771, + -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626, -0.0432316, 0.218258, + 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565, 0.269741, 0.130438, + 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669, 0.0868676, 0.044508, + -0.373961, 0.022618, -0.121195, -0.4065, -0.252054, -0.224905, 0.32421, + 0.344048, 0.271694, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.063456, + 0.148524, 0.05108, -0.0234895, 0.103489, 0.0142918, -0.123408, 0.0401075, + -0.0252804, 0.267356, 0.146353, 0.0789186, 0.187761, 0.0501726, -0.121584, + 0.0606723, -0.0681467, 0.382748, 0.230211, -0.161537}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional) auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + auto hs = mm->add_instruction( migraphx::make_op( "lstm", @@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional) ih, ic, pph); - mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ - -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, - 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, - 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; + -0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188, + 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812, + 0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional) auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + auto hs = mm->add_instruction( migraphx::make_op( "lstm", @@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional) ih, ic, pph); - mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ - -0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, - 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, - 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; + -0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334, + 0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813, + 0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional) auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional) seq, w, r); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, - 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, - -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, - 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, - 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, - -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, - 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, - -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, - 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, - -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, - 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, - -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, - -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, - -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; + -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.162851, -0.102647, -0.113827, + -0.142818, -0.0786602, -0.0613048, 0.179592, -0.071286, -0.123496, -0.153616, + -0.032874, -0.195349, -0.102509, -0.0372696, 0.252296, -0.144544, -0.1073, + -0.150145, 0.015065, -0.192699, -0.165194, -0.0372928, 0.273786, -0.100877, + -0.021205, -0.125423, 0.0206439, -0.187097, 0.0319021, -0.00298698, -0.0623361, + 0.0598866, 0.0513685, 0.0547876, 0.0201981, -0.00808453, 0.074206, 0.0124086, + -0.139544, 0.108016, 0.0192675, -0.108636, 0.098927, -0.140733, 0.00496085, + 0.0662588, -0.048577, -0.187329, -0.112764, -0.120496, 0.155754, 0.148256, + -0.0458544, -0.0401315, 0.0737483, -0.064505, -0.0051453, -0.0767618, -0.0735348, + -0.0826436, 0.101585, 0.0687269, -0.161725, -0.25617, -0.00520328, 0.0945081, + 0.264123, 0.410805, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.162602, + 0.0143099, -0.0455534, 0.0151574, 0.0855831, -0.0171894, -0.140202, 0.0828391, + 0.208491, 0.348432, 0.0291103, 0.230275, 0.136898, 0.00160891, -0.184812, + 0.147774, 0.214159, 0.262295, 0.0247127, 0.14472}; EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } @@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional) migraphx::program p; auto* mm = p.get_main_module(); seq_len = 1; - migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; + migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; std::vector input_data1{ -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1}); auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); - mm->add_instruction( + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( migraphx::make_op( "lstm", {{"hidden_size", hidden_size}, @@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional) seq, w, r); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); p.compile(migraphx::make_target("ref")); auto hs_concat = p.eval({}).back(); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + std::vector output_data_gold{ - -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, - -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, - -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, - -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; + -0.0327039, -0.0543852, 0.114378, -0.0768855, -0.104351, -0.0471426, + -0.0905753, 0.01506, 0.0319021, -0.00298698, -0.0623361, 0.0598866, + 0.059797, 0.104239, -0.0266768, 0.0727547, 0.101585, 0.0687269, + -0.161725, -0.25617, -0.146298, 0.070535, 0.327809, 0.407388}; + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); } } @@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens) } } +TEST_CASE(lstm_bidirectional_var_seq_lens_layout) +{ + std::size_t batch_size = 3; + std::size_t seq_len = 4; + std::size_t hidden_size = 4; + std::size_t input_size = 3; + std::size_t num_dirct = 2; + std::vector w_data{ + 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, + 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, + -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, + -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, + -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, + -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, + 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, + -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, + 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, + 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; + + std::vector r_data{ + 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, + -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, + -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, + -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, + 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, + 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, + -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, + 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, + 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, + 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, + 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, + -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, + 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; + + std::vector bias_data{ + 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, + -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, + 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, + -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, + 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, + -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, + -0.4386, 0.4208, 0.0717, 0.3789}; + + std::vector input_data{ + -0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366, + 0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332, + 1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331, + -1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310}; + + std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986, + 0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968, + 0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962, + 0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959}; + + std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998, + 0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831, + 2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498, + -1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141}; + + std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, + -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, + -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, + -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, + 0.7794, 1.8194, -0.5811, 0.2166}; + + float clip = 0.0f; + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + // concatenation of hidden states as program output + { + std::vector sl_data{1, 2, 3}; + migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data}); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto out_hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs); + auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs); + std::vector perm_hid{2, 0, 1, 3}; + out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), + out_hs); + lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho); + lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco); + mm->add_return({out_hs, lho, lco}); + p.compile(migraphx::make_target("ref")); + + auto outputs = p.eval({}); + auto arg_hs = outputs.front(); + auto arg_lho = outputs.at(1); + auto arg_lco = outputs.at(2); + + std::vector output_data; + std::vector last_output_data; + std::vector last_cell_data; + + arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); + arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); }); + arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); }); + + std::vector output_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804, + 0.0745128, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0.294074, -0.0319677, -0.0955337, + 0.104168, 0.911307, 0.11468, 0.114449, 0.0196755, 0.421857, 0.0459771, + -0.144955, 0.0720673, 0.96657, 0.0755112, 0.0620917, -0.264845, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0.022618, -0.121195, -0.4065, -0.252054, -0.262807, 0.275286, + 0.358395, 0.266267, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.128254, + 0.125398, 0.0665142, -0.163651, 0.103489, 0.0142918, -0.123408, 0.0401075, + -0.0644683, 0.371512, 0.212431, -0.116131, 0, 0, 0, + 0, 0, 0, 0, 0}; + + std::vector last_output_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804, 0.0745128, + 0.421857, 0.0459771, -0.144955, 0.0720673, 0.911307, 0.11468, 0.114449, 0.0196755, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.262807, 0.275286, 0.358395, 0.266267}; + + std::vector last_cell_data_gold{ + 0.600582, -0.601197, 0.353558, 0.789097, -0.326822, 0.301121, 0.219523, 0.415242, + 0.737121, 0.134902, -0.303595, 0.241948, 2.08242, 0.442513, 0.187127, 0.0577626, + 0.391174, 0.0308845, -0.561745, 0.0730323, -0.611307, 0.55454, 0.4364, 0.509436}; + + EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold)); + EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold)); + EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold)); + } + + // last cell output as program output + { + migraphx::program p; + auto* mm = p.get_main_module(); + auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data}); + auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data}); + auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data}); + auto w = mm->add_literal(migraphx::literal{w_shape, w_data}); + auto r = mm->add_literal(migraphx::literal{r_shape, r_data}); + auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data}); + auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data}); + migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}}; + std::vector pad_data(pad_seq_s.elements(), 0.0f); + auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data}); + auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p); + migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}}; + std::vector len_data(batch_size, static_cast(seq_len)); + auto sql = mm->add_literal(seq_len_s, len_data); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}, + {"input_forget", 0}}), + seq, + w, + r, + bias, + sql, + ih, + ic, + pph); + auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + std::vector perm_hid{2, 0, 1, 3}; + hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho); + lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco); + mm->add_return({hs, lho, lco}); + p.compile(migraphx::make_target("ref")); + + auto outputs = p.eval({}); + auto res_hs = outputs.at(0); + auto res_lho = outputs.at(1); + auto res_lco = outputs.at(2); + std::vector hs_data; + std::vector lho_data; + std::vector lco_data; + res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); + res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); }); + res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); }); + + std::vector hs_data_gold{ + 0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138, + -0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549, + 0.178681, -0.266999, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.182201, + -0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746, + -0.185038, -0.026845, 0.177273, -0.0774616, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0.294074, + -0.0319677, -0.0955337, 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812, + 0.421857, 0.0459771, -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626, + -0.0432316, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565, + 0.269741, 0.130438, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669, + 0.0868676, 0.044508, -0.373961, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0.022618, -0.121195, + -0.4065, -0.252054, -0.224905, 0.32421, 0.344048, 0.271694, -0.0300906, + -0.0890598, -0.135266, -0.0413375, -0.063456, 0.148524, 0.05108, -0.0234895, + 0.103489, 0.0142918, -0.123408, 0.0401075, -0.0252804, 0.267356, 0.146353, + 0.0789186, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.0681467, 0.382748, + 0.230211, -0.161537, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0}; + + std::vector lho_data_gold{ + -0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188, + 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812, + 0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694}; + + std::vector lco_data_gold{ + -0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334, + 0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813, + 0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713}; + + EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold)); + EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold)); + EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold)); + } +} + TEST_CASE(lstm_bidirectional_actv_func) { std::size_t batch_size = 3; diff --git a/test/verify/test_lstm_bidirct_3args_layout.cpp b/test/verify/test_lstm_bidirct_3args_layout.cpp new file mode 100644 index 00000000000..80fa1a52e4c --- /dev/null +++ b/test/verify/test_lstm_bidirct_3args_layout.cpp @@ -0,0 +1,77 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_3args_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_bidirct_last_layout.cpp b/test/verify/test_lstm_bidirct_last_layout.cpp new file mode 100644 index 00000000000..5d57cba5ef8 --- /dev/null +++ b/test/verify/test_lstm_bidirct_last_layout.cpp @@ -0,0 +1,95 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_bidirct_last_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 2; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto output = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_hs_layout.cpp b/test/verify/test_lstm_forward_hs_layout.cpp new file mode 100644 index 00000000000..aefa42bf6dc --- /dev/null +++ b/test/verify/test_lstm_forward_hs_layout.cpp @@ -0,0 +1,95 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_hs_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + auto und = mm->add_instruction(migraphx::make_op("undefined")); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + und, + ih, + ic, + pph); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_forward_last_layout.cpp b/test/verify/test_lstm_forward_last_layout.cpp new file mode 100644 index 00000000000..d882a005dd8 --- /dev/null +++ b/test/verify/test_lstm_forward_last_layout.cpp @@ -0,0 +1,97 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_forward_last_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; + migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}}; + migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}}; + migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; + + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + auto bias = mm->add_parameter("bias", b_shape); + auto ih = mm->add_parameter("ih", ih_shape); + auto len = mm->add_literal(migraphx::literal(l_shape, {1, 2})); + auto ic = mm->add_parameter("ic", ic_shape); + auto pph = mm->add_parameter("pph", pph_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih); + ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic); + + auto output = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r, + bias, + len, + ih, + ic, + pph); + auto last_output = + mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_reverse_3args_cell_output_layout.cpp b/test/verify/test_lstm_reverse_3args_cell_output_layout.cpp new file mode 100644 index 00000000000..1f2c25c45f3 --- /dev/null +++ b/test/verify/test_lstm_reverse_3args_cell_output_layout.cpp @@ -0,0 +1,78 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_reverse_3args_cell_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_reverse_3args_layout.cpp b/test/verify/test_lstm_reverse_3args_layout.cpp new file mode 100644 index 00000000000..19433a696f4 --- /dev/null +++ b/test/verify/test_lstm_reverse_3args_layout.cpp @@ -0,0 +1,78 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_reverse_3args_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)}, + {"clip", clip}}), + seq, + w, + r); + std::vector perm_hid{2, 0, 1, 3}; + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + + return p; + } + std::string section() const { return "rnn"; } +}; diff --git a/test/verify/test_lstm_three_outputs_layout.cpp b/test/verify/test_lstm_three_outputs_layout.cpp new file mode 100644 index 00000000000..adc858b5910 --- /dev/null +++ b/test/verify/test_lstm_three_outputs_layout.cpp @@ -0,0 +1,85 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +#include + +#include + +struct test_lstm_three_outputs_layout : verify_program +{ + migraphx::program create_program() const + { + std::size_t batch_size = 2; + std::size_t seq_len = 3; + std::size_t hidden_size = 5; + std::size_t input_size = 8; + std::size_t num_dirct = 1; + float clip = 0.0f; + + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}}; + migraphx::shape w_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, input_size}}; + migraphx::shape r_shape{migraphx::shape::float_type, + {num_dirct, 4 * hidden_size, hidden_size}}; + auto seq = mm->add_parameter("seq", in_shape); + auto w = mm->add_parameter("w", w_shape); + auto r = mm->add_parameter("r", r_shape); + + std::vector perm{1, 0, 2}; + seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq); + + auto hs = mm->add_instruction( + migraphx::make_op( + "lstm", + {{"hidden_size", hidden_size}, + {"actv_func", + migraphx::to_value(std::vector{migraphx::make_op("sigmoid"), + migraphx::make_op("tanh"), + migraphx::make_op("tanh")})}, + {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)}, + {"clip", clip}}), + seq, + w, + r); + auto last_hs = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs); + auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs); + std::vector perm_hid{2, 0, 1, 3}; + hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs); + last_hs = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_hs); + last_cell = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell); + mm->add_return({hs, last_hs, last_cell}); + + return p; + } + std::string section() const { return "rnn"; } +};