Skip to content

Commit

Permalink
normalize standard input shapes to hipblaslt gemms (#3712)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca authored Dec 14, 2024
1 parent f56b1b4 commit de20bd0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ struct find_hipblas_gemm_pointwise : gemm_pointwise
shape s = c_ins->get_shape();
// const-fold input if not standard shape
// Updated for a case where "standard" shape has out-of-sequence strides
if(not s.standard() or s.normalize_standard() != s)
if(not s.standard())
{
auto c = make_op("contiguous");
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
Expand Down
7 changes: 4 additions & 3 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ hipDataType get_type_hipblas(shape::type_t type)
MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
}

void blas_shape_hip(const shape& s)
void blas_shape_hip(const shape& in_shape)
{
if(s.lens().size() < 2)
if(in_shape.lens().size() < 2)
return;
auto s = in_shape.normalize_standard();
if(std::none_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(std::any_of(s.strides().end() - 2, s.strides().end(), [](auto i) { return i == 0; }))
Expand Down Expand Up @@ -669,7 +670,7 @@ void hip_gemm_compute(context& ctx,
std::transform(args.begin(),
args.end(),
std::back_inserter(input_shapes),
[](const argument& x) { return x.get_shape(); });
[](const argument& x) { return x.get_shape().normalize_standard(); });
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
gemm_item.run(ctx, args, solution_idx);
}
Expand Down

0 comments on commit de20bd0

Please sign in to comment.