Skip to content

Commit

Permalink
Update finalize function for hipBLASLt to find tuned solution only when solution index is 0. (#3593)
Browse files Browse the repository at this point in the history

Currently, the finalize function runs tuning for every solution index.
This PR modifies the function so that it runs tuning only when the
solution index is 0. For any other solution index, it runs validate instead.

This PR also fixes a bug in the hipBLASLt validate function: its status check was inverted, so it warned and reverted to the default solution when the call succeeded.
  • Loading branch information
ahsan-ca authored Nov 13, 2024
1 parent 495d3eb commit f1a1f9f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
22 changes: 15 additions & 7 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,9 @@ struct hip_gemm_impl
int32_t
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) // const
{
hipblasStatus_t check_valid(HIPBLAS_STATUS_SUCCESS);
auto common_args = create_hipblaslt_args_common(ctx, input_args, solution_idx);
check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args);
if(check_valid == HIPBLAS_STATUS_SUCCESS)
auto check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args, false);
if(check_valid != HIPBLAS_STATUS_SUCCESS)
{
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
return 0;
Expand Down Expand Up @@ -639,10 +638,19 @@ int32_t hip_gemm_finalize(context& ctx,
float beta,
int32_t solution_idx)
{
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
int32_t solution = gemm_item.tune(ctx, input_shapes);
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx);
return solution;
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
if(solution_idx == 0)
{
solution_idx = gemm_item.tune(ctx, input_shapes);
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx);
}
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different hipBLASLt version.
else
{
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
}
return solution_idx;
}

int32_t hip_gemm_default_solution(context& ctx,
Expand Down
11 changes: 8 additions & 3 deletions src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,20 @@ inline auto hipblaslt_invoke(F f, Ts... xs)
return status;
}

// Invoke a hipBLASLt call. If used to validate a call, set fatal_error = false to prevent
// throwing an exception on failure.
template <class F, class Pack, class... Ts>
auto hipblaslt_invoke(F f, Pack p, Ts... xs)
auto hipblaslt_invoke(F f, Pack p, Ts... xs, bool fatal_error = true)
{
return p([=](auto... ws) {
auto status = f(ws..., xs...);
if(status != HIPBLAS_STATUS_SUCCESS)
{
MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " +
std::to_string(status));
if(fatal_error)
{
MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " +
std::to_string(status));
}
}
return status;
});
Expand Down

0 comments on commit f1a1f9f

Please sign in to comment.