Skip to content

Commit d736d1c

Browse files
joeatoddjiyang1011mehdi-golijiyang1011
authored
Define benchmarking input file for April release (#316)
This PR does two things: - moves the benchmarking input files into their own directory - defines a new (incomplete) input file for the April PyTorch deadline --------- Co-authored-by: jiyang1011 <[email protected]> Co-authored-by: mehdi-goli <[email protected]> Co-authored-by: jiyang1011 <[email protected]@pvc125074:~/cutlass-tod$>
1 parent 1b62e2f commit d736d1c

File tree

8 files changed

+1388
-78
lines changed

8 files changed

+1388
-78
lines changed
File renamed without changes.

benchmarks/benchmark_runner.hpp

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "cutlass/gemm/device/gemm_universal_adapter.h"
3737
#include "cutlass/gemm/collective/collective_mma.hpp"
3838
#include "cutlass/util/GPU_Clock.hpp"
39+
#include "cutlass/epilogue/fusion/operations.hpp"
3940

4041
#include "cutlass/util/host_tensor.h"
4142
#include "cutlass/util/reference/host/tensor_fill.h"
@@ -47,6 +48,7 @@
4748
#include "cutlass/util/reference/device/gemm_complex.h"
4849
#include "cutlass/util/reference/device/tensor_compare.h"
4950
#include "cutlass/util/reference/device/tensor_fill.h"
51+
#include "cutlass/util/reference/device/tensor_silu.h"
5052

5153
#include <benchmark/benchmark.h>
5254

@@ -169,6 +171,37 @@ struct BenchmarkRunnerGemm {
169171

170172
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
171173

174+
using FusionOp = typename Gemm::EpilogueOutputOp;
175+
176+
// TODO(codeplay): Epilogue detection here should be replaced w/ general solution (see other TODO)
177+
using FusionSilu = cutlass::epilogue::fusion::LinCombEltAct<
178+
cutlass::epilogue::thread::SiLu, ElementOutput, ElementCompute, ElementAccumulator,
179+
ElementAccumulator, cutlass::FloatRoundStyle::round_to_nearest>;
180+
181+
using FusionDeEltMul = cutlass::epilogue::fusion::LinCombDeEltAct<LayoutC, std::multiplies,
182+
ElementOutput, ElementCompute>;
183+
using FusionLinComb = epilogue::fusion::LinearCombination<
184+
ElementOutput, ElementCompute, ElementAccumulator, ElementAccumulator,
185+
FloatRoundStyle::round_to_nearest>;
186+
187+
// Epilogue used in ampere/gemm_configuration.hpp
188+
using DefaultEpilogue = epilogue::collective::DefaultEpilogue<
189+
float,
190+
cutlass::gemm::TagToStrideC_t<LayoutC>,
191+
cutlass::gemm::TagToStrideC_t<LayoutC>,
192+
epilogue::thread::LinearCombination<float, 1>,
193+
cutlass::gemm::EpilogueDefault>;
194+
195+
static constexpr bool epi_is_deeltactmul = std::is_same_v<FusionOp, FusionDeEltMul>;
196+
static constexpr bool epi_is_silu = std::is_same_v<FusionOp, FusionSilu>;
197+
static constexpr bool epi_is_lincomb = std::is_same_v<FusionOp, FusionLinComb>;
198+
static constexpr bool epi_is_default = std::is_same_v<CollectiveEpilogue, DefaultEpilogue>;
199+
static_assert(cute::is_base_of_v<cutlass::epilogue::fusion::FusionOperation, FusionOp> ||
200+
epi_is_default,
201+
"Failed to determine benchmark epilogue");
202+
static_assert(epi_is_default || epi_is_deeltactmul || epi_is_silu || epi_is_lincomb,
203+
"Failed to determine benchmark epilogue");
204+
172205
int32_t count;
173206

174207
//
@@ -188,6 +221,7 @@ struct BenchmarkRunnerGemm {
188221
std::vector<DeviceAllocation<ElementC>> block_C;
189222
DeviceAllocation<ElementOutput> block_D;
190223
DeviceAllocation<ElementOutput> block_ref_D;
224+
std::vector<DeviceAllocation<ElementOutput>> block_Aux;
191225

192226
BenchmarkRunnerGemm() : seed(0) {};
193227

@@ -227,6 +261,20 @@ struct BenchmarkRunnerGemm {
227261
cudaDeviceSynchronize();
228262
#endif
229263

264+
// TODO(codeplay): Replace this with a general solution (hook up to Testbed3x)
265+
if constexpr (epi_is_silu) {
266+
using TensorView = cutlass::TensorView<ElementOutput, LayoutD>;
267+
for (int batch = 0, offset = 0; batch < L; batch++, offset += M * N) {
268+
cutlass::reference::device::TensorSiLu(TensorView(
269+
block_ref_D.get() + offset, LayoutD::packed({M, N}), cutlass::make_Coord(M, N)));
270+
}
271+
} else if constexpr (epi_is_deeltactmul) {
272+
cutlass::reference::device::BlockElementwiseOp<std::multiplies>(
273+
block_ref_D.get(), block_ref_D.get(), block_Aux[0].get(), block_D.size());
274+
}
275+
276+
syclcompat::wait();
277+
230278
// Check if output from CUTLASS kernel and reference kernel are equal or not
231279
bool passed = reference::device::BlockCompareEqual(
232280
block_ref_D.get(), block_D.get(), block_D.size());
@@ -256,6 +304,9 @@ struct BenchmarkRunnerGemm {
256304
block_A.emplace_back();
257305
block_B.emplace_back();
258306
block_C.emplace_back();
307+
if constexpr (epi_is_deeltactmul) {
308+
block_Aux.emplace_back();
309+
}
259310
}
260311

261312
for (int i=0; i < count; i++) {
@@ -265,6 +316,10 @@ struct BenchmarkRunnerGemm {
265316
initialize_block(block_A[i], seed + i);
266317
initialize_block(block_B[i], seed + i);
267318
initialize_block(block_C[i], seed + i);
319+
if constexpr (epi_is_deeltactmul) {
320+
block_Aux[i].reset(size_C);
321+
initialize_block(block_Aux[i], seed + i);
322+
}
268323
}
269324

270325
block_D.reset(size_C);
@@ -284,6 +339,11 @@ struct BenchmarkRunnerGemm {
284339
arguments.epilogue = {{options.alpha, options.beta}, block_C[0].get(), stride_C, block_D.get(), stride_D};
285340
arguments.hw_info = hw_info;
286341

342+
if constexpr(epi_is_deeltactmul){
343+
arguments.epilogue.thread.aux_ptr = block_Aux[0].get();
344+
arguments.epilogue.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
345+
}
346+
287347
Gemm gemm_op;
288348

289349
size_t workspace_size = Gemm::get_workspace_size(arguments);
@@ -352,6 +412,10 @@ struct BenchmarkRunnerGemm {
352412
{{options.alpha, options.beta}, block_C[input_num].get(), stride_C, block_D.get(), stride_D},
353413
hw_info
354414
};
415+
if constexpr(epi_is_deeltactmul){
416+
arguments.epilogue.thread.aux_ptr = block_Aux[input_num].get();
417+
arguments.epilogue.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
418+
}
355419
gemm_op.initialize(arguments, workspace.get());
356420
state.ResumeTiming();
357421

@@ -370,18 +434,20 @@ struct BenchmarkRunnerGemm {
370434
/// Seeds the timing counters stored on `state` before any sample is recorded.
///
/// "avg_runtime_ms" is set to zero. "best_runtime_ms" starts at the largest
/// finite double and "worst_runtime_ms" at its negation, so the min/max
/// comparison against the first recorded sample replaces both sentinels.
///
/// @param state Benchmark state whose `counters` map receives the entries.
static void initialize_counters(::benchmark::State& state) {
  using Limits = std::numeric_limits<double>;
  auto& counters = state.counters;

  counters["avg_runtime_ms"] = 0;
  counters["best_runtime_ms"] = Limits::max();
  counters["worst_runtime_ms"] = -Limits::max();
}
374439

375440
/// Folds one timed sample into the running counters stored on `state`.
///
/// Adds `ms_elapsed` to "total_runtime_ms" and keeps "best_runtime_ms" /
/// "worst_runtime_ms" as the smallest / largest sample seen so far. The
/// bookkeeping is bracketed by PauseTiming()/ResumeTiming() calls on `state`.
///
/// @param state      Benchmark state owning the counters.
/// @param ms_elapsed Duration of the sample just measured, in milliseconds.
static void update_counters(::benchmark::State& state, double ms_elapsed) {
  state.PauseTiming();

  auto& counters = state.counters;
  // Read the current extremes first; the comparisons below use plain doubles.
  const double best_so_far = counters["best_runtime_ms"];
  const double worst_so_far = counters["worst_runtime_ms"];

  counters["total_runtime_ms"] += ms_elapsed;
  counters["best_runtime_ms"] = std::min<double>(best_so_far, ms_elapsed);
  counters["worst_runtime_ms"] = std::max<double>(worst_so_far, ms_elapsed);

  state.ResumeTiming();
}
381447

382448
static void finalize_counters(::benchmark::State& state, double gflop, double mega_bytes_transferred) {
383449
state.counters["avg_runtime_ms"] =
384-
state.counters["total_runtime_ms"] / static_cast<double>(state.iterations());
450+
(state.counters["total_runtime_ms"] -state.counters["best_runtime_ms"] - state.counters["worst_runtime_ms"] ) / static_cast<double>(state.iterations() - 2);
385451
state.counters["avg_tflops"] = gflop / state.counters["avg_runtime_ms"];
386452
state.counters["avg_throughput"] = mega_bytes_transferred / state.counters["avg_runtime_ms"];
387453
state.counters["best_tflop"] = gflop / state.counters["best_runtime_ms"];

0 commit comments

Comments
 (0)