Skip to content

Commit

Permalink
Add changes for contiguous transpose gemm fusion for hipblaslt (#3706)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca authored Dec 14, 2024
1 parent e8bfc2c commit b0072d9
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 12 deletions.
87 changes: 77 additions & 10 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,16 +753,8 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
};
#endif

struct find_contiguous_tranpose_gemm
struct contiguous_transpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}

template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
Expand All @@ -773,6 +765,17 @@ struct find_contiguous_tranpose_gemm
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
};

struct find_contiguous_transpose_rocblas_gemm : contiguous_transpose_gemm
{
// Match a gpu::contiguous whose only input is a transpose feeding from a
// rocBLAS gpu::gemm that is used exactly once; binds the gemm instruction
// as "gemm" and the transpose instruction as "transpose" for apply().
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}

void apply(module& m, const match::matcher_result& r) const
{
Expand Down Expand Up @@ -811,6 +814,67 @@ struct find_contiguous_tranpose_gemm
}
};

#if MIGRAPHX_USE_HIPBLASLT
// Fuses a gemm -> transpose -> contiguous chain (hipBLASLt path) by asking
// the gemm to produce a batch-transposed output directly (trans_batch = 1),
// which removes the need for the explicit gpu::contiguous copy.
struct find_contiguous_transpose_hip_gemm : contiguous_transpose_gemm
{
    auto matcher() const
    {
        // Match gpu::contiguous(transpose(gpu::hipblaslt_op)) where the
        // hipblaslt op has a single user; bind "hip_gemm" and "transpose".
        return match::name("gpu::contiguous")(match::arg(0)(
            match::name("transpose")(
                match::arg(0)(
                    match::name("gpu::hipblaslt_op")(match::used_once()).bind("hip_gemm")))
                .bind("transpose")));
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["hip_gemm"];
        auto gemm_op  = any_cast<hipblaslt_op>(gemm_ins->get_operator()).op;

        // hipblaslt_op wraps several ops; only plain gemm is fusable here.
        if(gemm_op.name() != "gpu::hip_gemm")
            return;

        auto gemm = any_cast<hip_gemm<op::dot>>(gemm_op);

        // Last input of a gpu op is its output allocation.
        auto alloc     = gemm_ins->inputs().back();
        auto transpose = r.instructions["transpose"];
        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
        auto iperm     = invert_permutation(perm);

        // Need at least {batch, rows, cols} for a batch transpose.
        if(perm.size() < 3)
            return;

        // Only fuse when the transpose swaps the batch axis with the rows
        // axis (third-from-last and second-from-last dimensions).
        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
            return;

        // Any leading dimensions beyond the last three must be trivial.
        auto lens = gemm_ins->get_shape().lens();
        if(lens.size() > 3 and
           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
            return;

        // Ask the gemm to emit its output with the batch axis transposed.
        gemm.trans_batch = 1;

        // Allocate in the inversely-permuted layout, then view it through the
        // transpose so the gemm writes directly into the desired layout.
        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
        auto new_alloc =
            m.insert_instruction(gemm_ins, make_op("allocate", {{"shape", to_value(s)}}));

        auto alloc_transpose = m.insert_instruction(
            gemm_ins, make_op("transpose", {{"permutation", perm}}), new_alloc);

        auto inputs   = gemm_ins->inputs();
        inputs.back() = alloc_transpose;
        operation new_gemm_op = gemm;
        auto new_gemm         = m.insert_instruction(
            gemm_ins, make_op("gpu::hipblaslt_op", {{"op", to_value(new_gemm_op)}}), inputs);

        // Fixed typo: was `gemm_transpoe`.
        auto gemm_transpose = m.insert_instruction(gemm_ins, transpose->get_operator(), new_gemm);

        m.replace_instruction(ins, gemm_transpose);
    }
};
#endif

struct find_commutative_broadcast
{
auto matcher() const
Expand Down Expand Up @@ -980,7 +1044,10 @@ void fuse_ops::apply(module& m) const
match::find_matches(m,
find_layernorm_pointwise{},
find_concat_pointwise{},
find_contiguous_tranpose_gemm{},
find_contiguous_transpose_rocblas_gemm{},
#if MIGRAPHX_USE_HIPBLASLT
find_contiguous_transpose_hip_gemm{},
#endif
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
Expand Down
14 changes: 14 additions & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
#include <migraphx/permutation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -111,6 +112,19 @@ void blas_shape_hip(const shape& in_shape)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}

// Returns `s` with its batch axis (third-from-last) swapped `trans_batch`
// positions toward the rows axis; returns `s` unchanged when no transpose is
// requested or the shape has fewer than three dimensions.
shape transpose_batch_hip(const shape& s, unsigned trans_batch)
{
    const auto ndim = s.lens().size();
    if(trans_batch == 0 or ndim < 3)
        return s;
    const auto batch_axis = ndim - 3;
    std::vector<int64_t> permutation(ndim);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[batch_axis], permutation[batch_axis + trans_batch]);
    return shape::from_permutation(s.type(), s.lens(), permutation);
}

static bool is_transposed_hip(const shape& s) { return s.transposed() and s.strides().back() != 1; }

static int32_t get_batch_stride_hip(const shape& s)
Expand Down
8 changes: 6 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/hip_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,24 @@ namespace gpu {

struct context;
void blas_shape_hip(const shape& s);
shape transpose_batch_hip(const shape& s, unsigned trans_batch);

template <class Op>
struct hip_gemm
{
Op op;
float alpha = 1;
float beta = 0;
unsigned trans_batch = 0;
int32_t solution_idx = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.alpha, "alpha"),
f(self.beta, "beta"),
f(self.trans_batch, "trans_batch"),
f(self.solution_idx, "solution_idx")));
}

Expand Down Expand Up @@ -98,10 +102,10 @@ struct hip_gemm
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
return op_out_shape;
return transpose_batch_hip(op_out_shape, trans_batch);
}

return op.compute_shape(in_shapes);
return transpose_batch_hip(op.compute_shape(in_shapes), trans_batch);
}

argument
Expand Down

0 comments on commit b0072d9

Please sign in to comment.