@@ -36,6 +36,7 @@
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
 #include "cutlass/gemm/collective/collective_mma.hpp"
 #include "cutlass/util/GPU_Clock.hpp"
+#include "cutlass/epilogue/fusion/operations.hpp"
 
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/reference/host/tensor_fill.h"
@@ -47,6 +48,7 @@
 #include "cutlass/util/reference/device/gemm_complex.h"
 #include "cutlass/util/reference/device/tensor_compare.h"
 #include "cutlass/util/reference/device/tensor_fill.h"
+#include "cutlass/util/reference/device/tensor_silu.h"
 
 #include <benchmark/benchmark.h>
 
@@ -169,6 +171,37 @@ struct BenchmarkRunnerGemm {
 
   using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
 
+  using FusionOp = typename Gemm::EpilogueOutputOp;
+
+  // TODO(codeplay): Epilogue detection here should be replaced w/ general solution (see other TODO)
+  using FusionSilu = cutlass::epilogue::fusion::LinCombEltAct<
+      cutlass::epilogue::thread::SiLu, ElementOutput, ElementCompute, ElementAccumulator,
+      ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using FusionDeEltMul = cutlass::epilogue::fusion::LinCombDeEltAct<LayoutC, std::multiplies,
+      ElementOutput, ElementCompute>;
+  using FusionLinComb = epilogue::fusion::LinearCombination<
+      ElementOutput, ElementCompute, ElementAccumulator, ElementAccumulator,
+      FloatRoundStyle::round_to_nearest>;
+
+  // Epilogue used in ampere/gemm_configuration.hpp
+  using DefaultEpilogue = epilogue::collective::DefaultEpilogue<
+      float,
+      cutlass::gemm::TagToStrideC_t<LayoutC>,
+      cutlass::gemm::TagToStrideC_t<LayoutC>,
+      epilogue::thread::LinearCombination<float, 1>,
+      cutlass::gemm::EpilogueDefault>;
+
+  static constexpr bool epi_is_deeltactmul = std::is_same_v<FusionOp, FusionDeEltMul>;
+  static constexpr bool epi_is_silu = std::is_same_v<FusionOp, FusionSilu>;
+  static constexpr bool epi_is_lincomb = std::is_same_v<FusionOp, FusionLinComb>;
+  static constexpr bool epi_is_default = std::is_same_v<CollectiveEpilogue, DefaultEpilogue>;
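+  // Note: DefaultEpilogue is not a fusion::FusionOperation, so it is detected via the
+  // CollectiveEpilogue type itself rather than via FusionOp.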
+  static_assert(cute::is_base_of_v<cutlass::epilogue::fusion::FusionOperation, FusionOp> ||
+                    epi_is_default,
+                "Failed to determine benchmark epilogue");
+  static_assert(epi_is_default || epi_is_deeltactmul || epi_is_silu || epi_is_lincomb,
+                "Failed to determine benchmark epilogue");
+
   int32_t count;
 
   //
@@ -188,6 +221,7 @@ struct BenchmarkRunnerGemm {
   std::vector<DeviceAllocation<ElementC>> block_C;
   DeviceAllocation<ElementOutput> block_D;
   DeviceAllocation<ElementOutput> block_ref_D;
+  std::vector<DeviceAllocation<ElementOutput>> block_Aux;
 
   BenchmarkRunnerGemm() : seed(0) {};
 
@@ -227,6 +261,20 @@ struct BenchmarkRunnerGemm {
     cudaDeviceSynchronize();
 #endif
 
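+    // The device reference GEMM computed only the plain linear combination, so apply
+    // the fused epilogue's elementwise step to the reference output before comparing.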
+    // TODO(codeplay): Replace this with a general solution (hook up to Testbed3x)
+    if constexpr (epi_is_silu) {
+      using TensorView = cutlass::TensorView<ElementOutput, LayoutD>;
+      for (int batch = 0, offset = 0; batch < L; batch++, offset += M * N) {
+        cutlass::reference::device::TensorSiLu(TensorView(
+            block_ref_D.get() + offset, LayoutD::packed({M, N}), cutlass::make_Coord(M, N)));
+      }
+    } else if constexpr (epi_is_deeltactmul) {
+      cutlass::reference::device::BlockElementwiseOp<std::multiplies>(
+          block_ref_D.get(), block_ref_D.get(), block_Aux[0].get(), block_D.size());
+    }
+
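+    // Wait for the reference fix-up kernels to finish before the device-side compare.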
+    syclcompat::wait();
+
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = reference::device::BlockCompareEqual(
         block_ref_D.get(), block_D.get(), block_D.size());
@@ -256,6 +304,9 @@ struct BenchmarkRunnerGemm {
       block_A.emplace_back();
       block_B.emplace_back();
       block_C.emplace_back();
+      if constexpr (epi_is_deeltactmul) {
+        block_Aux.emplace_back();
+      }
     }
 
     for (int i=0; i < count; i++) {
@@ -265,6 +316,10 @@ struct BenchmarkRunnerGemm {
       initialize_block(block_A[i], seed + i);
       initialize_block(block_B[i], seed + i);
       initialize_block(block_C[i], seed + i);
+      if constexpr (epi_is_deeltactmul) {
+        block_Aux[i].reset(size_C);
+        initialize_block(block_Aux[i], seed + i);
+      }
     }
 
     block_D.reset(size_C);
@@ -284,6 +339,11 @@ struct BenchmarkRunnerGemm {
     arguments.epilogue = {{options.alpha, options.beta}, block_C[0].get(), stride_C, block_D.get(), stride_D};
     arguments.hw_info = hw_info;
 
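+    // This epilogue also reads an auxiliary input tensor: pass its pointer and a packed (m,n,l) stride.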
+    if constexpr (epi_is_deeltactmul) {
+      arguments.epilogue.thread.aux_ptr = block_Aux[0].get();
+      arguments.epilogue.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
+    }
+
     Gemm gemm_op;
 
     size_t workspace_size = Gemm::get_workspace_size(arguments);
@@ -352,6 +412,10 @@ struct BenchmarkRunnerGemm {
         {{options.alpha, options.beta}, block_C[input_num].get(), stride_C, block_D.get(), stride_D},
         hw_info
       };
+      if constexpr (epi_is_deeltactmul) {
+        arguments.epilogue.thread.aux_ptr = block_Aux[input_num].get();
+        arguments.epilogue.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
+      }
       gemm_op.initialize(arguments, workspace.get());
       state.ResumeTiming();
@@ -370,18 +434,20 @@ struct BenchmarkRunnerGemm {
   static void initialize_counters(::benchmark::State& state) {
     state.counters["avg_runtime_ms"] = 0;
     state.counters["best_runtime_ms"] = std::numeric_limits<double>::max();
+    state.counters["worst_runtime_ms"] = -std::numeric_limits<double>::max();
   }
 
   static void update_counters(::benchmark::State& state, double ms_elapsed) {
     state.PauseTiming();
     state.counters["total_runtime_ms"] += ms_elapsed;
     state.counters["best_runtime_ms"] = std::min<double>(state.counters["best_runtime_ms"], ms_elapsed);
+    state.counters["worst_runtime_ms"] = std::max<double>(state.counters["worst_runtime_ms"], ms_elapsed);
     state.ResumeTiming();
   }
 
   static void finalize_counters(::benchmark::State& state, double gflop, double mega_bytes_transferred) {
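+    // Trimmed mean: drop the single best and worst samples, then average the remaining iterations.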
state.counters [" avg_runtime_ms" ] =
384
- state.counters [" total_runtime_ms" ] / static_cast <double >(state.iterations ());
450
+ ( state.counters [" total_runtime_ms" ] -state. counters [ " best_runtime_ms " ] - state. counters [ " worst_runtime_ms " ] ) / static_cast <double >(state.iterations () - 2 );
385
451
state.counters [" avg_tflops" ] = gflop / state.counters [" avg_runtime_ms" ];
386
452
state.counters [" avg_throughput" ] = mega_bytes_transferred / state.counters [" avg_runtime_ms" ];
387
453
state.counters [" best_tflop" ] = gflop / state.counters [" best_runtime_ms" ];