Skip to content

Commit

Permalink
Move auto_contiguous before lowering (#3343)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Aug 29, 2024
1 parent e81e3f7 commit 51a24d8
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ struct MIGRAPHX_EXPORT shape
static std::string cpp_type(type_t t);

static bool is_integral(type_t t);
static bool is_compatible(const shape& actual, const shape& expected);

shape();
shape(type_t t);
Expand Down
24 changes: 24 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,30 @@ bool shape::is_integral(shape::type_t t)
return result;
}

bool shape::is_compatible(const shape& actual, const shape& expected)
{
// Check subshapes
if(expected.type() == shape::tuple_type)
return migraphx::equal(actual.sub_shapes(), expected.sub_shapes(), &is_compatible);
if(actual == expected)
return true;
if(actual.type() != expected.type())
return false;
// Only the expected can be dynamic
if(expected.dynamic())
return actual.ndim() == expected.ndim();
if(actual.dynamic())
return false;
if(actual.lens() != expected.lens())
return false;
// Check strides from dimensions that are not 1
return all_of(range(actual.lens().size()), [&](auto i) {
if(actual.lens()[i] == 1)
return true;
return actual.strides()[i] == expected.strides()[i];
});
}

shape::shape() : impl(shape_impl::default_shape()) {}

shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/code_object_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ shape code_object_op::compute_shape(std::vector<shape> inputs) const
std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) {
return s.normalize_standard();
});
if(flatten(einputs) != flatten(inputs))
if(not migraphx::equal(flatten(einputs), flatten(inputs), &shape::is_compatible))
MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" +
to_string_range(inputs) + "]");
return output;
Expand Down
3 changes: 2 additions & 1 deletion src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,8 @@ struct find_pointwise_mlir
m->add_return({ret});

auto inputs = find_inputs(map_ins, &mpm.get_module(), m);
mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m});
mpm.get_module().replace_instruction(
ins, ins->get_operator(), mlir_contiguous(mpm, inputs), {m});
}
};

Expand Down
4 changes: 3 additions & 1 deletion src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{},
enable_pass(not enabled(MIGRAPHX_ENABLE_NHWC{}), auto_contiguous{}),
eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
dead_code_elimination{},
rewrite_reduce{},
Expand All @@ -172,6 +172,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
fuse_concat{},
dead_code_elimination{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), auto_contiguous{}),
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{"gpu::contiguous"},
dead_code_elimination{},
Expand Down
63 changes: 63 additions & 0 deletions test/shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1099,5 +1099,68 @@ TEST_CASE(multi_within_bounds)
std::vector<std::size_t> multi_4 = {1, 2, 1};
EXPECT(not in_shape.multi_within_bounds(multi_4));
}
TEST_CASE(shape_is_compatible_diff_strides)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}};
migraphx::shape expected{migraphx::shape::float_type, {1, 1, 8}, {1, 1, 1}};
EXPECT(actual != expected);
EXPECT(migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_diff_lens)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 2, 8}, {8, 8, 1}};
migraphx::shape expected{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_diff_type)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 2, 8}};
migraphx::shape expected{migraphx::shape::half_type, {1, 2, 8}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_dynamic)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2, 4}};
migraphx::shape expected{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(actual != expected);
EXPECT(migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_dynamic_actual)
{
migraphx::shape actual{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
migraphx::shape expected{migraphx::shape::float_type, {1, 2, 2, 4}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_dynamic_diff_type)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2, 4}};
migraphx::shape expected{migraphx::shape::half_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_dynamic_diff_rank)
{
migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2}};
migraphx::shape expected{migraphx::shape::half_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_diff_strides_tuple)
{
migraphx::shape actual{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}}};
migraphx::shape expected{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}, {1, 1, 1}}};
EXPECT(actual != expected);
EXPECT(migraphx::shape::is_compatible(actual, expected));
}
TEST_CASE(shape_is_compatible_diff_lens_tuple)
{
migraphx::shape actual{migraphx::shape{migraphx::shape::float_type, {1, 2, 8}}};
migraphx::shape expected{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}}};
EXPECT(actual != expected);
EXPECT(not migraphx::shape::is_compatible(actual, expected));
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 51a24d8

Please sign in to comment.