Skip to content

Commit

Permalink
Some fixes (#504)
Browse files Browse the repository at this point in the history
* fix int64 mapping & fix onehot ir

* fix convert to boolean type
mzmssg authored Dec 6, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 2b95e46 commit a11b9f3
Showing 3 changed files with 33 additions and 10 deletions.
11 changes: 11 additions & 0 deletions src/nnfusion/core/kernels/kernel_emitter.cpp
Original file line number Diff line number Diff line change
@@ -259,8 +259,19 @@ FunctionUnit_p KernelEmitter::emit_source()
if (lp && FLAGS_fantares_mode && FLAGS_fdefault_device != "HLSL")
{
auto lp_str = lp->get_code();
// avoid modify naming string
lp_str = replace_sub_str(lp_str, "_uint64_t", "@underline_unsigned_integer@");
lp_str = replace_sub_str(lp_str, "_int64_t", "@underline_integer@");
lp_str = replace_sub_str(lp_str, "uint64_t_", "@unsigned_integer_underline@");
lp_str = replace_sub_str(lp_str, "int64_t_", "@integer_underline@");

lp_str = replace_sub_str(lp_str, "uint64_t", "unsigned long long");
lp_str = replace_sub_str(lp_str, "int64_t", "long long");

lp_str = replace_sub_str(lp_str, "@underline_unsigned_integer@", "_uint64_t");
lp_str = replace_sub_str(lp_str, "@underline_integer@", "_int64_t");
lp_str = replace_sub_str(lp_str, "@unsigned_integer_underline@", "uint64_t_");
lp_str = replace_sub_str(lp_str, "@integer_underline@", "int64_t_");
lp->modify_code(lp_str);
}
return lp;
Original file line number Diff line number Diff line change
@@ -38,11 +38,22 @@ REGISTER_OP(Convert)
NNFUSION_CHECK(ret == true) << "cast type is not supported: "
<< op->get_convert_element_type().c_type_string();
out_dtype = out_dtype == "char" ? "int8" : out_dtype;

return op::create_code_from_template(
"@output0@@data_layout@ = @input0@@[email protected](`@out_dtype@`);",
{{"data_layout",
vector_to_string<std::vector<std::string>>(
op::create_layout_from_dims(gnode->get_output_shape(0)))},
{"out_dtype", out_dtype}});
if (op->get_convert_element_type() == element::boolean)
{
return op::create_code_from_template(
"@output0@@data_layout@ = (@input0@@data_layout@ != 0).cast(`@out_dtype@`);",
{{"data_layout",
vector_to_string<std::vector<std::string>>(
op::create_layout_from_dims(gnode->get_output_shape(0)))},
{"out_dtype", out_dtype}});
}
else
{
return op::create_code_from_template(
"@output0@@data_layout@ = @input0@@[email protected](`@out_dtype@`);",
{{"data_layout",
vector_to_string<std::vector<std::string>>(
op::create_layout_from_dims(gnode->get_output_shape(0)))},
{"out_dtype", out_dtype}});
}
});
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@ REGISTER_OP(OneHot)
std::string dtype;
bool ret =
element::Type::nnfusion_element_type_to_dtype_string(gnode->get_element_type(), dtype);
NNFUSION_CHECK(ret) << "Unsupport data type: " << gnode->get_element_type();
NNFUSION_CHECK(ret) << "Unsupport data type: " << gnode->get_element_type();

auto input0_layout = op::create_layout_from_dims(gnode->get_input_shape(0));
auto output_layout = input0_layout;
@@ -55,7 +55,7 @@ REGISTER_OP(OneHot)
std::string expr =
"@output0@@output_layout@ = const(@on_value@).when([@input0@@input0_layout@ == "
"@axis@, @input0@@input0_layout@ + @depth@ == @axis@], @off_value@, "
"merge_op=`any`) where @axis@ in @depth@;";
"merge_op=`any`).cast(`@dtype@`) where @axis@ in @depth@;";

return op::create_code_from_template(
expr,
@@ -64,5 +64,6 @@ REGISTER_OP(OneHot)
{"depth", depth},
{"on_value", on_value},
{"off_value", off_value},
{"axis", output_layout[axis]}});
{"axis", output_layout[axis]},
{"dtype", dtype}});
});

0 comments on commit a11b9f3

Please sign in to comment.