Skip to content

Commit

Permalink
Update finalize function to run tune only when sol idx is 0.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca committed Nov 6, 2024
1 parent 624c8d3 commit 7180234
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,9 @@ struct hip_gemm_impl
int32_t
validate(context& ctx, const std::vector<argument>& input_args, int32_t solution_idx) // const
{
hipblasStatus_t check_valid(HIPBLAS_STATUS_SUCCESS);
auto common_args = create_hipblaslt_args_common(ctx, input_args, solution_idx);
check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args);
if(check_valid == HIPBLAS_STATUS_SUCCESS)
auto check_valid = hipblaslt_invoke(&hipblasLtMatmul, common_args);
if(check_valid != HIPBLAS_STATUS_SUCCESS)
{
std::cerr << "WARNING: tuned solution is invalid; reverting to default" << std::endl;
return 0;
Expand Down Expand Up @@ -633,9 +632,19 @@ int32_t hip_gemm_finalize(context& ctx,
float beta,
int32_t solution_idx)
{
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
int32_t solution = gemm_item.tune(ctx, input_shapes);
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution_idx);
int solution;
auto gemm_item = hip_gemm_impl(output_shape, input_shapes, alpha, beta);
if(solution_idx == 0)
{
solution = gemm_item.tune(ctx, input_shapes);
hip_gemm_save_solution(ctx, output_shape, input_shapes, solution);
}
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different hipBLAS version.
else
{
solution = gemm_item.validate(ctx, input_shapes, solution_idx);
}
return solution;
}

Expand Down

0 comments on commit 7180234

Please sign in to comment.