diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index da54c264fc2..1db85c4f934 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -110,6 +110,16 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8_ops.insert("dot"); } + // add all device kernels + unsupported_fp8_ops.insert("logsoftmax"); + unsupported_fp8_ops.insert("nonzero"); + unsupported_fp8_ops.insert("prefix_scan_sum"); + unsupported_fp8_ops.insert("scatter_none"); + unsupported_fp8_ops.insert("topk"); + unsupported_fp8_ops.insert("rnn_var_sl_shift_output"); + unsupported_fp8_ops.insert("multinomial"); + unsupported_fp8_ops.insert("argmax"); + unsupported_fp8_ops.insert("argmin"); // clang-format off return { diff --git a/test/verify/gemm_2args_mm_8.cpp b/test/verify/gemm_2args_mm_8.cpp index 982dbc003ed..5fc26ed98aa 100644 --- a/test/verify/gemm_2args_mm_8.cpp +++ b/test/verify/gemm_2args_mm_8.cpp @@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program> }; template struct gemm_2args_mm_8; -// template struct gemm_2args_mm_8; +// template struct gemm_2args_mm_8; // fails with CK, issue#2514 template struct gemm_2args_mm_8; diff --git a/test/verify/gemm_add_broadcast2.cpp b/test/verify/gemm_add_broadcast2.cpp index 15f35ad0628..a01ce1a79ec 100644 --- a/test/verify/gemm_add_broadcast2.cpp +++ b/test/verify/gemm_add_broadcast2.cpp @@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program> }; template struct gemm_add_broadcast2; -// template struct gemm_add_broadcast2; +// template struct gemm_add_broadcast2; // fails with CK, issue#2514 template struct gemm_add_broadcast2; diff --git a/test/verify/test_arg_ops.cpp b/test/verify/test_arg_ops.cpp index fdc237587bc..c4b60f327ee 100644 --- a/test/verify/test_arg_ops.cpp +++ b/test/verify/test_arg_ops.cpp @@ -29,14 +29,14 @@ #include #include -template -struct test_arg_ops : verify_program> +template +struct test_arg_ops : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {2, 1, 4, 1025}}; + migraphx::shape s{DType, {2, 1, 4, 1025}}; auto param = mm->add_parameter("data", s); switch(NonStdShape) { @@ -59,106 +59,211 @@ struct test_arg_ops : verify_program; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // transpose argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // broadcast argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // broadcast argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // slice argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // slice argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // default case, standard shape argmax tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; // default case, standard shape argmin tests -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; -template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; + +// transpose argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// transpose argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// broadcast argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// broadcast argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// slice argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// slice argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// default case, standard shape argmax tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +// default case, standard shape argmin tests +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; +template struct test_arg_ops; diff --git a/test/verify/test_contiguous.cpp b/test/verify/test_contiguous.cpp index 03f53fde5c4..efada842b03 100644 --- a/test/verify/test_contiguous.cpp +++ b/test/verify/test_contiguous.cpp @@ -29,16 +29,20 @@ #include -struct test_contiguous : verify_program +template +struct test_contiguous : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {4, 4, 4, 3}, {48, 4, 1, 16}}; + migraphx::shape s{DType, {4, 4, 4, 3}, {48, 4, 1, 16}}; auto x = mm->add_parameter("x", s); mm->add_instruction(migraphx::make_op("contiguous"), x); assert(p.get_output_shapes().back().standard()); return p; } }; + +template struct test_contiguous; +template struct test_contiguous; diff --git a/test/verify/test_logsoftmax.cpp b/test/verify/test_logsoftmax.cpp index ad8b7fb2d66..bd56887751e 100644 --- a/test/verify/test_logsoftmax.cpp +++ b/test/verify/test_logsoftmax.cpp @@ -50,3 +50,7 @@ template struct test_logsoftmax<1, migraphx::shape::half_type>; template struct test_logsoftmax<0, migraphx::shape::half_type>; template struct test_logsoftmax<2, migraphx::shape::half_type>; template struct test_logsoftmax<3, migraphx::shape::half_type>; +template struct test_logsoftmax<0, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_logsoftmax<1, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_logsoftmax<2, migraphx::shape::fp8e4m3fnuz_type>; +template struct test_logsoftmax<3, migraphx::shape::fp8e4m3fnuz_type>; diff --git a/test/verify/test_multinomial.cpp b/test/verify/test_multinomial.cpp index c7c294df0bd..86ecf6681a1 100644 --- a/test/verify/test_multinomial.cpp +++ b/test/verify/test_multinomial.cpp @@ -27,7 +27,8 @@ #include #include -struct test_multinomial : verify_program +template +struct test_multinomial : verify_program> { migraphx::program create_program() const { @@ -40,10 +41,10 @@ struct test_multinomial : verify_program std::uniform_real_distribution<> dis(0.0, 1.0); std::vector rand_samples(batch_size * sample_size); std::generate(rand_samples.begin(), rand_samples.end(), [&]() { return dis(gen); }); - migraphx::shape rs{migraphx::shape::float_type, {batch_size, sample_size}}; + migraphx::shape rs{DType, {batch_size, sample_size}}; auto rs_lit = mm->add_literal(migraphx::literal{rs, rand_samples}); - migraphx::shape s{migraphx::shape::float_type, {batch_size, 5}}; + migraphx::shape s{DType, {batch_size, 5}}; auto input = mm->add_parameter("input", s); auto maxes = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), input); @@ -58,3 +59,8 @@ struct test_multinomial : verify_program return p; } }; + +template struct test_multinomial; +template struct test_multinomial; +// This fails, need to figure out why +// template struct test_multinomial; diff --git a/test/verify/test_nonzero.cpp b/test/verify/test_nonzero.cpp index 47409a91b4a..eb939986a78 100644 --- a/test/verify/test_nonzero.cpp +++ b/test/verify/test_nonzero.cpp @@ -27,13 +27,14 @@ #include #include -struct test_nonzero : verify_program +template +struct test_nonzero : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}}; + migraphx::shape s{DType, {2, 3, 4, 5}}; auto x = mm->add_parameter("data", s); auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); mm->add_return({r}); @@ -41,3 +42,7 @@ struct test_nonzero : verify_program return p; } }; + +template struct test_nonzero; +template struct test_nonzero; +template struct test_nonzero; diff --git a/test/verify/test_nonzero_half.cpp b/test/verify/test_nonzero_half.cpp deleted file mode 100644 index 4621842eaaf..00000000000 --- a/test/verify/test_nonzero_half.cpp +++ /dev/null @@ -1,43 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_nonzero_half : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::half_type, {3, 4, 3, 5}}; - auto x = mm->add_parameter("data", s); - auto r = mm->add_instruction(migraphx::make_op("nonzero"), x); - mm->add_return({r}); - - return p; - } -}; diff --git a/test/verify/test_prefix_scan_sum_2d.cpp b/test/verify/test_prefix_scan_sum_2d.cpp index cd30e10caf1..8999c498616 100644 --- a/test/verify/test_prefix_scan_sum_2d.cpp +++ b/test/verify/test_prefix_scan_sum_2d.cpp @@ -23,16 +23,18 @@ */ #include "verify_program.hpp" #include +#include #include #include -struct test_prefix_scan_sum_2d_small : verify_program +template +struct test_prefix_scan_sum_2d_small : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {1}}; + migraphx::shape s{DType, {1}}; auto x = mm->add_parameter("x", s); auto xb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), x); @@ -42,16 +44,25 @@ struct test_prefix_scan_sum_2d_small : verify_program +template struct test_prefix_scan_sum_2d_small; +template struct test_prefix_scan_sum_2d_small; +template struct test_prefix_scan_sum_2d_small; + +template +struct test_prefix_scan_sum_2d_large : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {3, 1000}}; + migraphx::shape s{DType, {3, 1000}}; auto x = mm->add_parameter("x", s); mm->add_instruction( migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), x); return p; } }; + +template struct test_prefix_scan_sum_2d_large; +template struct test_prefix_scan_sum_2d_large; +template struct test_prefix_scan_sum_2d_large; diff --git a/test/verify/test_reverse.cpp b/test/verify/test_reverse.cpp index 8ac7c85edd7..c6af63c38a5 100644 --- a/test/verify/test_reverse.cpp +++ b/test/verify/test_reverse.cpp @@ -26,16 +26,21 @@ #include #include -struct test_reverse : verify_program +template +struct test_reverse : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {4, 16}}; + migraphx::shape s{DType, {4, 16}}; auto a0 = mm->add_parameter("data", s); std::vector axis = {0}; mm->add_instruction(migraphx::make_op("reverse", {{"axes", axis}}), a0); return p; } }; + +template struct test_reverse; +template struct test_reverse; +template struct test_reverse; diff --git a/test/verify/test_rnn_sql_1.cpp b/test/verify/test_rnn_sql_1.cpp index 25a71ec01d2..559e377380e 100644 --- a/test/verify/test_rnn_sql_1.cpp +++ b/test/verify/test_rnn_sql_1.cpp @@ -31,7 +31,8 @@ #include -struct test_rnn_sql_1 : verify_program +template +struct test_rnn_sql_1 : verify_program> { migraphx::program create_program() const { @@ -44,12 +45,12 @@ struct test_rnn_sql_1 : verify_program migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; - migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; - migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; - migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; + migraphx::shape in_shape{DType, {seq_len, batch_size, input_size}}; + migraphx::shape w_shape{DType, {num_dirct, hidden_size, input_size}}; + migraphx::shape r_shape{DType, {num_dirct, hidden_size, hidden_size}}; + migraphx::shape b_shape{DType, {num_dirct, 2 * hidden_size}}; migraphx::shape s_shape{migraphx::shape::int32_type, {batch_size}}; - migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; + migraphx::shape ih_shape{DType, {num_dirct, batch_size, hidden_size}}; auto seq = mm->add_parameter("seq", in_shape); auto w = mm->add_parameter("w", w_shape); @@ -81,3 +82,7 @@ struct test_rnn_sql_1 : verify_program } std::string section() const { return "rnn"; } }; + +template struct test_rnn_sql_1; +template struct test_rnn_sql_1; +template struct test_rnn_sql_1; diff --git a/test/verify/test_scatter0.cpp b/test/verify/test_scatter0.cpp index da2650a2cbd..f853f199764 100644 --- a/test/verify/test_scatter0.cpp +++ b/test/verify/test_scatter0.cpp @@ -27,16 +27,17 @@ #include #include -struct test_scatter0 : verify_program +template +struct test_scatter0 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape sd{migraphx::shape::float_type, {3, 3}}; + migraphx::shape sd{DType, {3, 3}}; migraphx::shape si{migraphx::shape::int32_type, {2, 3}}; std::vector vi = {1, 0, 2, 0, 2, 1}; - migraphx::shape su{migraphx::shape::float_type, {2, 3}}; + migraphx::shape su{DType, {2, 3}}; auto pd = mm->add_parameter("data", sd); auto li = mm->add_literal(migraphx::literal{si, vi}); @@ -47,3 +48,7 @@ struct test_scatter0 : verify_program return p; } }; + +template struct test_scatter0; +template struct test_scatter0; +template struct test_scatter0; diff --git a/test/verify/test_topk_0.cpp b/test/verify/test_topk_0.cpp index 8c99154ac08..643de2ba05c 100644 --- a/test/verify/test_topk_0.cpp +++ b/test/verify/test_topk_0.cpp @@ -27,13 +27,14 @@ #include #include -struct test_topk_0 : verify_program +template +struct test_topk_0 : verify_program> { migraphx::program create_program() const { migraphx::program p; auto* mm = p.get_main_module(); - migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + migraphx::shape s{DType, {3, 5}}; auto data = mm->add_parameter("data", s); auto r = mm->add_instruction( migraphx::make_op("topk", {{"axis", 1}, {"k", 4}, {"largest", 1}}), data); @@ -43,3 +44,7 @@ struct test_topk_0 : verify_program return p; } }; + +template struct test_topk_0; +template struct test_topk_0; +template struct test_topk_0;