diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index 69870c4455e..9ee0d005a15 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -632,20 +632,19 @@ int32_t hip_gemm_finalize(context& ctx, float beta, int32_t solution_idx) { - int solution; auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta); if(solution_idx == 0) { - solution = gemm_item.tune(ctx, input_shapes); + solution_idx = gemm_item.tune(ctx, input_shapes); hip_gemm_save_solution(ctx, output_shape, input_shapes, solution); } // If a tuned solution index is already given, don't tune again but validate - // in case the data was tuned with a different hipBLAS version. + // in case the data was tuned with a different hipBLASLt version. else { - solution = gemm_item.validate(ctx, input_shapes, solution_idx); + solution_idx = gemm_item.validate(ctx, input_shapes, solution_idx); } - return solution; + return solution_idx; } int32_t hip_gemm_default_solution(context& ctx,