diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index f5ec898d8d5..c207e7c5a82 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -462,10 +462,9 @@ struct hip_gemm_impl int32_t validate(context& ctx, const std::vector& input_args, int32_t solution_idx) // const { - hipblasStatus_t check_valid(HIPBLAS_STATUS_SUCCESS); auto common_args = create_hipblaslt_args_common(ctx, input_args, solution_idx); - check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args); - if(check_valid == HIPBLAS_STATUS_SUCCESS) + auto check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args, false); + if(check_valid != HIPBLAS_STATUS_SUCCESS) { std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl; return 0; @@ -633,10 +632,19 @@ int32_t hip_gemm_finalize(context& ctx, float beta, int32_t solution_idx) { - auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); - int32_t solution = gemm_item.tune(ctx, input_shapes); - hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx); - return solution; + auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); + if(solution_idx == 0) + { + solution_idx = gemm_item.tune(ctx, input_shapes); + hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx); + } + // If a tuned solution index is already given, don't tune again but validate + // in case the data was tuned with a different hipBLASLt version. + else + { + solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); + } + return solution_idx; } int32_t hip_gemm_default_solution(context& ctx, diff --git a/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp b/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp index 8b9ec2cef63..49d41bf4dcd 100644 --- a/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp @@ -75,15 +75,20 @@ inline auto hipblaslt_invoke(F f, Ts... 
xs) return status; } +// Invoke a hipBLASLt call. If used to validate a call, set fatal_error = false to prevent +// throwing an exception on failure. template <class F, class Pack, class... Ts> -auto hipblaslt_invoke(F f, Pack p, Ts... xs) +auto hipblaslt_invoke(F f, Pack p, Ts... xs, bool fatal_error = true) { return p([=](auto... ws) { auto status = f(ws..., xs...); if(status != HIPBLAS_STATUS_SUCCESS) { - MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " + - std::to_string(status)); + if(fatal_error) + { + MIGRAPHX_THROW("hipblaslt_invoke: hipBlasLt call failed with status " + + std::to_string(status)); + } } return status; });