@@ -283,7 +283,7 @@ struct BenchmarkRunnerGemm {
283
283
}
284
284
285
285
/// Initialize operands to be used in the GEMM and reference GEMM
286
- void initialize (const ProblemShapeType& problem_size) {
286
+ void initialize (::benchmark::State& state, const ProblemShapeType& problem_size) {
287
287
auto problem_shape_MNKL = cute::append<4 >(problem_size, 1 );
288
288
auto [M, N, K, L] = problem_shape_MNKL;
289
289
@@ -309,28 +309,31 @@ struct BenchmarkRunnerGemm {
309
309
}
310
310
}
311
311
312
- for (int i=0 ; i < count; i++) {
313
- block_A[i].reset (size_A);
314
- block_B[i].reset (size_B);
315
- block_C[i].reset (size_C);
316
- initialize_block (block_A[i], seed + i);
317
- initialize_block (block_B[i], seed + i);
318
- initialize_block (block_C[i], seed + i);
319
- if constexpr (epi_is_deeltactmul) {
320
- block_Aux[i].reset (size_C);
321
- initialize_block (block_Aux[i], seed + i);
312
+ try {
313
+ for (int i = 0 ; i < count; i++) {
314
+ block_A[i].reset (size_A);
315
+ block_B[i].reset (size_B);
316
+ block_C[i].reset (size_C);
317
+ initialize_block (block_A[i], seed + i);
318
+ initialize_block (block_B[i], seed + i);
319
+ initialize_block (block_C[i], seed + i);
320
+ if constexpr (epi_is_deeltactmul) {
321
+ block_Aux[i].reset (size_C);
322
+ initialize_block (block_Aux[i], seed + i);
323
+ }
322
324
}
323
- }
324
-
325
- block_D.reset (size_C);
326
- block_ref_D.reset (size_C);
327
325
326
+ block_D.reset (size_C);
327
+ block_ref_D.reset (size_C);
328
+ } catch (std::exception const &e) {
329
+ state.SkipWithError (e.what ());
330
+ }
328
331
}
329
332
330
333
void run (::benchmark::State& state, const GEMMOptions& options, const KernelHardwareInfo& hw_info) {
331
334
ProblemShapeType problem_size = ProblemShapeType{options.m , options.n , options.k , options.l };
332
335
333
- initialize (problem_size);
336
+ initialize (state, problem_size);
334
337
335
338
typename Gemm::GemmKernel::Arguments arguments = GemmConfiguration::defaultArguments ();
336
339
arguments.mode = gemm::GemmUniversalMode::kGemm ;
@@ -349,9 +352,13 @@ struct BenchmarkRunnerGemm {
349
352
size_t workspace_size = Gemm::get_workspace_size (arguments);
350
353
device_memory::allocation<uint8_t > workspace (workspace_size);
351
354
352
- gemm_op.can_implement (arguments);
355
+ if (gemm_op.can_implement (arguments) != cutlass::Status::kSuccess )
356
+ state.SkipWithError (" GEMM unable to implement given args." );
357
+
358
+ if (gemm_op.initialize (arguments, workspace.get ()) != cutlass::Status::kSuccess )
359
+ state.SkipWithError (" GEMM failed to initialize." );
353
360
354
- gemm_op. initialize (arguments, workspace. get ());
361
+ if (state. error_occurred ()) return ;
355
362
356
363
// Run the GEMM
357
364
gemm_op.run ();
0 commit comments