From a11b9f3c1d2f90bc6b4d11222baeb219dd19432d Mon Sep 17 00:00:00 2001 From: Ziming Miao Date: Tue, 6 Dec 2022 17:46:08 +0800 Subject: [PATCH] Some fixes (#504) * fix int64 mapping & fix onehot ir * fix convert to boolean type --- src/nnfusion/core/kernels/kernel_emitter.cpp | 11 ++++++++ .../generic_op/generic_op_define/Convert.cpp | 25 +++++++++++++------ .../generic_op/generic_op_define/OneHot.cpp | 7 +++--- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/nnfusion/core/kernels/kernel_emitter.cpp b/src/nnfusion/core/kernels/kernel_emitter.cpp index 3e68e36a9..8c419f690 100644 --- a/src/nnfusion/core/kernels/kernel_emitter.cpp +++ b/src/nnfusion/core/kernels/kernel_emitter.cpp @@ -259,8 +259,19 @@ FunctionUnit_p KernelEmitter::emit_source() if (lp && FLAGS_fantares_mode && FLAGS_fdefault_device != "HLSL") { auto lp_str = lp->get_code(); + // avoid modify naming string + lp_str = replace_sub_str(lp_str, "_uint64_t", "@underline_unsigned_integer@"); + lp_str = replace_sub_str(lp_str, "_int64_t", "@underline_integer@"); + lp_str = replace_sub_str(lp_str, "uint64_t_", "@unsigned_integer_underline@"); + lp_str = replace_sub_str(lp_str, "int64_t_", "@integer_underline@"); + lp_str = replace_sub_str(lp_str, "uint64_t", "unsigned long long"); lp_str = replace_sub_str(lp_str, "int64_t", "long long"); + + lp_str = replace_sub_str(lp_str, "@underline_unsigned_integer@", "_uint64_t"); + lp_str = replace_sub_str(lp_str, "@underline_integer@", "_int64_t"); + lp_str = replace_sub_str(lp_str, "@unsigned_integer_underline@", "uint64_t_"); + lp_str = replace_sub_str(lp_str, "@integer_underline@", "int64_t_"); lp->modify_code(lp_str); } return lp; diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/Convert.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/Convert.cpp index 886cb043e..0bf103324 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/Convert.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/Convert.cpp @@ -38,11 +38,22 @@ REGISTER_OP(Convert) NNFUSION_CHECK(ret == true) << "cast type is not supported: " << op->get_convert_element_type().c_type_string(); out_dtype = out_dtype == "char" ? "int8" : out_dtype; - - return op::create_code_from_template( - "@output0@@data_layout@ = @input0@@data_layout@.cast(`@out_dtype@`);", - {{"data_layout", - vector_to_string>( - op::create_layout_from_dims(gnode->get_output_shape(0)))}, - {"out_dtype", out_dtype}}); + if (op->get_convert_element_type() == element::boolean) + { + return op::create_code_from_template( + "@output0@@data_layout@ = (@input0@@data_layout@ != 0).cast(`@out_dtype@`);", + {{"data_layout", + vector_to_string>( + op::create_layout_from_dims(gnode->get_output_shape(0)))}, + {"out_dtype", out_dtype}}); + } + else + { + return op::create_code_from_template( + "@output0@@data_layout@ = @input0@@data_layout@.cast(`@out_dtype@`);", + {{"data_layout", + vector_to_string>( + op::create_layout_from_dims(gnode->get_output_shape(0)))}, + {"out_dtype", out_dtype}}); + } }); diff --git a/src/nnfusion/core/operators/generic_op/generic_op_define/OneHot.cpp b/src/nnfusion/core/operators/generic_op/generic_op_define/OneHot.cpp index e91cc1dac..5db3c0205 100644 --- a/src/nnfusion/core/operators/generic_op/generic_op_define/OneHot.cpp +++ b/src/nnfusion/core/operators/generic_op/generic_op_define/OneHot.cpp @@ -45,7 +45,7 @@ REGISTER_OP(OneHot) std::string dtype; bool ret = element::Type::nnfusion_element_type_to_dtype_string(gnode->get_element_type(), dtype); - NNFUSION_CHECK(ret) << "Unsupport data type: " << gnode->get_element_type(); + NNFUSION_CHECK(ret) << "Unsupport data type: " << gnode->get_element_type(); auto input0_layout = op::create_layout_from_dims(gnode->get_input_shape(0)); auto output_layout = input0_layout; @@ -55,7 +55,7 @@ REGISTER_OP(OneHot) std::string expr = "@output0@@output_layout@ = const(@on_value@).when([@input0@@input0_layout@ == " "@axis@, @input0@@input0_layout@ + @depth@ == @axis@], @off_value@, " - "merge_op=`any`) where @axis@ in @depth@;"; + "merge_op=`any`).cast(`@dtype@`) where @axis@ in @depth@;"; return op::create_code_from_template( expr, @@ -64,5 +64,6 @@ REGISTER_OP(OneHot) {"depth", depth}, {"on_value", on_value}, {"off_value", off_value}, - {"axis", output_layout[axis]}}); + {"axis", output_layout[axis]}, + {"dtype", dtype}}); });