Skip to content

Commit da9c63f

Browse files
joeatoddFMarno
andauthored
Catch various errors in benchmark execution (#334)
Modify the `benchmark_runner` to skip when memory allocation fails (e.g. large problem size on BMG). Also check the `can_implement` and `initialize` methods. --------- Co-authored-by: Finlay <[email protected]>
1 parent 235bec3 commit da9c63f

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

benchmarks/benchmark_runner.hpp

+25-18
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ struct BenchmarkRunnerGemm {
283283
}
284284

285285
/// Initialize operands to be used in the GEMM and reference GEMM
286-
void initialize(const ProblemShapeType& problem_size) {
286+
void initialize(::benchmark::State& state, const ProblemShapeType& problem_size) {
287287
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
288288
auto [M, N, K, L] = problem_shape_MNKL;
289289

@@ -309,28 +309,31 @@ struct BenchmarkRunnerGemm {
309309
}
310310
}
311311

312-
for (int i=0; i < count; i++) {
313-
block_A[i].reset(size_A);
314-
block_B[i].reset(size_B);
315-
block_C[i].reset(size_C);
316-
initialize_block(block_A[i], seed + i);
317-
initialize_block(block_B[i], seed + i);
318-
initialize_block(block_C[i], seed + i);
319-
if constexpr (epi_is_deeltactmul) {
320-
block_Aux[i].reset(size_C);
321-
initialize_block(block_Aux[i], seed + i);
312+
try {
313+
for (int i = 0; i < count; i++) {
314+
block_A[i].reset(size_A);
315+
block_B[i].reset(size_B);
316+
block_C[i].reset(size_C);
317+
initialize_block(block_A[i], seed + i);
318+
initialize_block(block_B[i], seed + i);
319+
initialize_block(block_C[i], seed + i);
320+
if constexpr (epi_is_deeltactmul) {
321+
block_Aux[i].reset(size_C);
322+
initialize_block(block_Aux[i], seed + i);
323+
}
322324
}
323-
}
324-
325-
block_D.reset(size_C);
326-
block_ref_D.reset(size_C);
327325

326+
block_D.reset(size_C);
327+
block_ref_D.reset(size_C);
328+
} catch (std::exception const &e) {
329+
state.SkipWithError(e.what());
330+
}
328331
}
329332

330333
void run(::benchmark::State& state, const GEMMOptions& options, const KernelHardwareInfo& hw_info) {
331334
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
332335

333-
initialize(problem_size);
336+
initialize(state, problem_size);
334337

335338
typename Gemm::GemmKernel::Arguments arguments = GemmConfiguration::defaultArguments();
336339
arguments.mode = gemm::GemmUniversalMode::kGemm;
@@ -349,9 +352,13 @@ struct BenchmarkRunnerGemm {
349352
size_t workspace_size = Gemm::get_workspace_size(arguments);
350353
device_memory::allocation<uint8_t> workspace(workspace_size);
351354

352-
gemm_op.can_implement(arguments);
355+
if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess)
356+
state.SkipWithError("GEMM unable to implement given args.");
357+
358+
if (gemm_op.initialize(arguments, workspace.get()) != cutlass::Status::kSuccess)
359+
state.SkipWithError("GEMM failed to initialize.");
353360

354-
gemm_op.initialize(arguments, workspace.get());
361+
if (state.error_occurred()) return;
355362

356363
// Run the GEMM
357364
gemm_op.run();

0 commit comments

Comments
 (0)