FP Exception for dim (-1) in reshape operator (#2865) (#2878)
causten authored Apr 3, 2024
1 parent 05a7707 commit dc9351b
Showing 2 changed files with 36 additions and 3 deletions.
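Background for the change below: the reshape operator's "dims" attribute may contain the placeholders 0 (keep the corresponding input dimension) and -1 (infer the dimension from the remaining element count). Per the commit title, forwarding a -1 placeholder to MLIR could trigger a floating-point exception, so the patch passes the already-resolved output lengths from ins->get_shape().lens() instead. The standalone sketch that follows only illustrates how such placeholders resolve to concrete lengths; resolve_reshape_dims is a hypothetical helper written for this explanation, not a MIGraphX API.

// Hypothetical, standalone sketch (not MIGraphX code): resolve reshape
// placeholder dims against an input shape. For an input of lengths {1, 2, 2, 2}
// (8 elements) and dims {1, -1, 1, 2}, the result is {1, 4, 1, 2}, which is
// what the new test below expects MLIR to receive.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<std::int64_t> resolve_reshape_dims(const std::vector<std::int64_t>& input_lens,
                                               std::vector<std::int64_t> dims)
{
    const std::int64_t total = std::accumulate(
        input_lens.begin(), input_lens.end(), std::int64_t{1}, std::multiplies<>{});
    std::int64_t known = 1;
    std::size_t infer  = dims.size(); // index of the -1 entry, if any
    for(std::size_t i = 0; i < dims.size(); ++i)
    {
        if(dims[i] == 0)
            dims[i] = input_lens[i]; // 0 keeps the input dimension
        if(dims[i] == -1)
            infer = i; // -1 is inferred from the remaining elements
        else
            known *= dims[i];
    }
    if(infer != dims.size())
        dims[infer] = total / known; // e.g. 8 / (1 * 1 * 2) = 4
    return dims;
}

MIGraphX performs this resolution during shape inference, so by the time get_operator_value runs, ins->get_shape().lens() already holds the concrete lengths that the diff below hands to MLIR.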
13 changes: 10 additions & 3 deletions src/targets/gpu/mlir.cpp
@@ -610,9 +610,16 @@ struct mlir_program
         return "migraphx." + ins->name();
     }

-    static value get_operator_value(const operation& op)
+    static value get_operator_value(instruction_ref ins)
     {
-        auto v = op.to_value();
+        const operation& op = ins->get_operator();
+        auto v = op.to_value();
+
+        // Reshape operator can have dim 0 or -1.
+        // Avoid passing those on to MLIR:
+        if(op.name() == "reshape")
+            v["dims"] = ins->get_shape().lens();
+
         if(op.name() == "convolution" or op.name() == "quant_convolution")
         {
             // Adjust symmetrical padding
@@ -668,7 +675,7 @@ struct mlir_program
         }
         auto name = get_name(ins);
         auto ops = create_operation_state(name);
-        ops.add_attribute_value(get_operator_value(ins->get_operator()));
+        ops.add_attribute_value(get_operator_value(ins));
         if(ins->name() != "@return")
             ops.add_results({get_shape(ins)});
         if(ins->name() == "@literal")
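To make the effect concrete, here is a short, hedged fragment reusing the module that the new test below constructs (m, conv, and the migraphx calls are taken from that test); the exact contents of the serialized operator value are an assumption based on the reshape attributes used there.

// Fragment, not a complete program: assumes the conv/reshape module from the
// test added below. The raw operator value still carries the user-supplied
// placeholder, while the instruction's computed shape holds the resolved
// lengths that get_operator_value() now forwards to MLIR.
auto reshape = m.add_instruction(
    migraphx::make_op("reshape", {{"dims", {1, -1, 1, 2}}}), conv);
auto raw_dims      = reshape->get_operator().to_value()["dims"]; // [1, -1, 1, 2] (assumed serialization)
auto resolved_dims = reshape->get_shape().lens();                // {1, 4, 1, 2}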
26 changes: 26 additions & 0 deletions test/gpu/mlir.cpp
@@ -227,6 +227,32 @@ module {
     EXPECT(verify_mlir(m));
 }

+// The following test checks that a dimension of -1 within the reshape operator is handled properly.
+TEST_CASE(conv_reshape_dim_minus_one)
+{
+    const std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_convolution_reshape(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x4x1x2xf32, 8x2x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %1 = migraphx.reshape %0 {dims = [1, 4, 1, 2]} : <1x2x2x2xf32, 8x4x2x1> -> <1x4x1x2xf32, 8x2x2x1>
+    return %1 : !migraphx.shaped<1x4x1x2xf32, 8x2x2x1>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
+    auto w = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}});
+    auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
+    auto reshape = m.add_instruction(migraphx::make_op("reshape", {{"dims", {1, -1, 1, 2}}}), conv);
+    m.add_return({reshape});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    CHECK(encode(s) == encode(mlir_output));
+    EXPECT(verify_mlir(m));
+}
+
 TEST_CASE(quant_dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
