Skip to content

Commit f1a1f9f

Browse files
authored
Update finalize function for hipBLASLt to find tuned solution only when solution index is 0. (#3593)
Currently, the finalize function runs tuning for all solution indices. This PR modifies the function to run tuning only when the solution index is 0. For all other solution indices, it runs validate instead. This PR also fixes a bug in the hipBLASLt validate function.
1 parent 495d3eb commit f1a1f9f

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

src/targets/gpu/hip_gemm_impl.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,9 @@ struct hip_gemm_impl
468468
int32_t
469469
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) // const
470470
{
471-
hipblasStatus_t check_valid(HIPBLAS_STATUS_SUCCESS);
472471
auto common_args = create_hipblaslt_args_common(ctx, input_args, solution_idx);
473-
check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args);
474-
if(check_valid == HIPBLAS_STATUS_SUCCESS)
472+
auto check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args, false);
473+
if(check_valid != HIPBLAS_STATUS_SUCCESS)
475474
{
476475
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
477476
return 0;
@@ -639,10 +638,19 @@ int32_t hip_gemm_finalize(context& ctx,
639638
float beta,
640639
int32_t solution_idx)
641640
{
642-
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
643-
int32_t solution = gemm_item.tune(ctx, input_shapes);
644-
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx);
645-
return solution;
641+
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
642+
if(solution_idx == 0)
643+
{
644+
solution_idx = gemm_item.tune(ctx, input_shapes);
645+
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx);
646+
}
647+
// If a tuned solution index is already given, don't tune again but validate
648+
// in case the data was tuned with a different hipBLASLt version.
649+
else
650+
{
651+
solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx);
652+
}
653+
return solution_idx;
646654
}
647655

648656
int32_t hip_gemm_default_solution(context& ctx,

src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,20 @@ inline auto hipblaslt_invoke(F f, Ts... xs)
7575
return status;
7676
}
7777

78+
// Invoke a hipBLASLt call. If used to validate a call, set fatal_error = false to prevent
79+
// throwing an exception on failure.
7880
template <class F, class Pack, class... Ts>
79-
auto hipblaslt_invoke(F f, Pack p, Ts... xs)
81+
auto hipblaslt_invoke(F f, Pack p, Ts... xs, bool fatal_error = true)
8082
{
8183
return p([=](auto... ws) {
8284
auto status = f(ws..., xs...);
8385
if(status != HIPBLAS_STATUS_SUCCESS)
8486
{
85-
MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " +
86-
std::to_string(status));
87+
if(fatal_error)
88+
{
89+
MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " +
90+
std::to_string(status));
91+
}
8792
}
8893
return status;
8994
});

0 commit comments

Comments
 (0)