The basic MLIR-based end-to-end GPU Tensor Core GEMM codegen. (#1199)
This is the first version providing a naive end-to-end GPU GEMM codegen flow.
JamesTheZ authored Sep 11, 2023
1 parent 8cb5c2c commit da990f4
Showing 28 changed files with 940 additions and 201 deletions.
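One user-visible change worth noting before the diffs: transform schedule files supplied through `DISC_TRANSFORM_SCHEDULE_FILE` are now tagged with a target device, so the existing CPU tests switch from the `kGEMM::` prefix to `kGEMM::CPU:`, and the new GPU flow is selected with `kGEMM::GPU:`. A minimal sketch of how the new flow is exercised, mirroring the `MatMulF16_GPU_256x256x128` test added to `matmul.cc` below (`EnvSetting`, `EnvSettingContext`, and `c_ft_path` are pre-existing helpers from that test file):

```cpp
// Sketch based on the new GPU test below: select the GPU Tensor Core GEMM
// schedule added by this commit and enable the transform-dialect pipeline.
EnvSetting setting = {
    {"DISC_TRANSFORM_SCHEDULE_FILE",
     {"kGEMM::GPU:" + c_ft_path + "matmul_nn_s_f16_gpu_schedule.mlir", false}},
    {"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
    {"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}}};
EnvSettingContext ctx(setting);  // applies the variables for the test's scope
```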
4 changes: 2 additions & 2 deletions docs/developers/pass_pipeline.md
@@ -556,11 +556,11 @@ specialized into:
scf.if %cond {
"lmhlo.fusion"() ( {
...
}) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_thread_per_block_hint = 256 : i32, disc_vectorize_or_tile_hint = 2 : i32} : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_cta_size_hint = 256 : i32, disc_vectorize_or_tile_hint = 2 : i32} : () -> ()
} else {
"lmhlo.fusion"() ( {
...
}) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2X_no_vectile", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_thread_per_block_hint = 256 : i32, disc_vectorize_or_tile_hint = 1 : i32} : () -> ()
}) {disc.device = "gpu", disc.fusion.name = "main_kRowReduction_reduce_0", disc.fusion.tag = "1b1rX_vectile2X_no_vectile", disc.fusion_type = "kRowReduction", disc_row_reduction_schedule_hint = 1 : i32, disc_cta_size_hint = 256 : i32, disc_vectorize_or_tile_hint = 1 : i32} : () -> ()
```

The different "disc_vectorize_or_tile_hint" attributes will guide the codegen passes to
25 changes: 23 additions & 2 deletions tao_compiler/mlir/disc/BUILD
@@ -1351,6 +1351,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Transforms",
@@ -2241,6 +2242,25 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_erase_buffer_deallocation",
srcs = ["transforms/disc_erase_buffer_deallocation.cc"],
deps = [
":lmhlo_disc",
":disc_util",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:BufferizationTransforms",
],
alwayslink = 1,
)

cc_library(
name = "all_passes",
hdrs = [
@@ -2264,12 +2284,12 @@ cc_library(
":disc_dense_to_sparse",
":disc_convert_const_to_ral",
":disc_convert_fake_quant_op",
":disc_custom_call_rewriter",
":disc_cpu_map_parallel_loop",
":disc_custom_call_rewriter",
":disc_duplicate_computation_after_fusion",
":disc_duplicate_computation_for_fusion",
":disc_dynamic_slice_converter",
":disc_sparse_op_rewriter",
":disc_erase_buffer_deallocation",
":disc_flatten_memref_access",
":disc_for_loop_unroll_interleave",
":disc_fuse_splat_const",
@@ -2294,6 +2314,7 @@ cc_library(
":disc_shape_optimization",
":disc_shape_simplifier",
":disc_shape_to_std",
":disc_sparse_op_rewriter",
":disc_specialize_fusion_with_speculation",
":disc_std_bufferize",
":disc_stitch_fusion",
10 changes: 8 additions & 2 deletions tao_compiler/mlir/disc/disc_compiler.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
@@ -518,7 +519,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
tensorflow::ReadStringFromEnvVar("DISC_TRANSFORM_SCHEDULE_FILE", "",
&transform_schedule);
pm.addNestedPass<FuncOp>(disc_ral::createDiscTransformLegalizeToLoopPass(
gpu_enabled, transform_schedule));
gpu_enabled, transform_schedule, options.gpu_info.cc_major,
options.gpu_info.cc_minor));
}

pm.addNestedPass<FuncOp>(createCanonicalizerPass());
@@ -599,7 +601,7 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
// TODO: adopt tileSize from attributes of speculation pass with a
// wrapper of the original ParallelLoopTilingPass
pm.addNestedPass<FuncOp>(
disc_ral::createParallelLoopTilingPass({kThreadsRowReduction}, true));
disc_ral::createParallelLoopTilingPass({kCTASizeDefault}, true));
// pm.addNestedPass<FuncOp>(disc_ral::createMapParallelLoopsPass());
pm.addNestedPass<FuncOp>(mlir::createGpuMapParallelLoopsPass());

@@ -640,6 +642,10 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
kernelPm.addPass(createLoopInvariantCodeMotionPass());
kernelPm.addPass(createCSEPass());
}
kernelPm.addNestedPass<gpu::GPUFuncOp>(
disc_ral::createDiscEraseBufferDeallocationPass());
kernelPm.addNestedPass<gpu::GPUFuncOp>(
memref::createExpandStridedMetadataPass());
kernelPm.addPass(createConvertSCFToCFPass());
kernelPm.addPass(createLowerAffinePass());
kernelPm.addNestedPass<FuncOp>(createCanonicalizerPass());
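The kernel-side pipeline above now runs two extra passes on every `gpu.func` before SCF-to-CF lowering: the new `DiscEraseBufferDeallocation` pass and the upstream `memref::createExpandStridedMetadataPass`. The source of the new pass (`transforms/disc_erase_buffer_deallocation.cc`, registered in the BUILD change above) is not part of this excerpt; the transform schedule below uses an analogous `transform.disc.erase_dealloc` op. A minimal sketch of such a pass, assuming all it needs to do is drop `memref.dealloc` ops that are meaningless inside device code, could look like:

```cpp
// Hypothetical sketch only; not the actual
// transforms/disc_erase_buffer_deallocation.cc from this commit.
// Assumption: after bufferization, leftover memref.dealloc ops inside a
// gpu.func must simply be erased, since device code cannot free buffers.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace {
struct EraseBufferDeallocationSketch
    : public mlir::PassWrapper<EraseBufferDeallocationSketch,
                               mlir::OperationPass<mlir::gpu::GPUFuncOp>> {
  void runOnOperation() override {
    // Collect first, then erase, to avoid mutating the IR while walking it.
    llvm::SmallVector<mlir::memref::DeallocOp, 4> deallocs;
    getOperation().walk(
        [&](mlir::memref::DeallocOp op) { deallocs.push_back(op); });
    for (mlir::memref::DeallocOp op : deallocs) op->erase();
  }
};
}  // namespace
```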
@@ -0,0 +1,9 @@
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} {
func.func @main(%arg0: tensor<256x128xf16>, %arg1: tensor<128x256xf16>) -> (tensor<256x256xf16>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} {
%graph = tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<256x128xf16>, tensor<128x256xf16>) -> (tensor<256x256xf16>)
tf_executor.fetch %0 : tensor<256x256xf16>
}
return %graph : tensor<256x256xf16>
}
}
@@ -0,0 +1,9 @@
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 0 : i32}} {
func.func @main(%arg0: tensor<256x128xf16>, %arg1: tensor<128x256xf16>) -> (tensor<256x256xf16>) attributes {tf.entry_function = {inputs = "{{INPUTS}}", outputs = "{{OUTPUTS}}", input_placements="{{INPUT_PLACEMENTS}}", output_placements="{{OUTPUT_PLACEMENTS}}"}} {
%graph = tf_executor.graph {
%0:2 = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<256x128xf16>, tensor<128x256xf16>) -> (tensor<256x256xf16>)
tf_executor.fetch %0 : tensor<256x256xf16>
}
return %graph : tensor<256x256xf16>
}
}
@@ -0,0 +1,45 @@
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.structured.match attributes {disc.transform.name = "dot_general"} in %arg0 : (!transform.any_op) -> !transform.any_op
%1:2 = split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%forall_op, %tiled_op = transform.structured.tile_to_forall_op %1#1 num_threads [] tile_sizes [128, 128](mapping = [#gpu.block<x>, #gpu.block<y>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1#0 into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%for_op, %splitted_op = transform.disc.split_reduction_serial %tiled_op by tile_sizes = [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%promoted_dot, %lhs_alloc, %rhs_alloc = transform.disc.promote_dot_operands %for_op [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%forall_op_0, %tiled_op_1 = transform.structured.tile_to_forall_op %promoted_dot num_threads [] tile_sizes [64, 64](mapping = [#gpu.warp<x>, #gpu.warp<y>]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%for_op_2, %splitted_op_3 = transform.disc.split_reduction_serial %tiled_op_1 by tile_sizes = [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%tiled_linalg_op, %loops:3 = transform.structured.tile %for_op_2[16, 8, 16] {interchange = [0, 1, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.disc.apply_licm %arg0 : !transform.any_op
transform.disc.apply_dce %arg0 : !transform.any_op
transform.disc.apply_cse %arg0 : !transform.any_op
%2 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%3 = transform.disc.apply_patterns %2 {canonicalization} : (!transform.any_op) -> !transform.any_op
%4 = transform.structured.vectorize %3 {vectorize_padding} : (!transform.any_op) -> !transform.any_op
transform.disc.apply_dce %arg0 : !transform.any_op
transform.disc.apply_cse %arg0 : !transform.any_op
%5 = transform.disc.bufferize {target_gpu} %arg0 : (!transform.any_op) -> !transform.any_op
%6 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.erase_dealloc %6 : (!transform.any_op) -> ()
%7 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.transfer_write_zero_to_scf %7 : (!transform.any_op) -> ()
transform.disc.apply_dce %5 : !transform.any_op
transform.disc.apply_cse %5 : !transform.any_op
%8 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.block<x>, #gpu.block<y>]} in %5 : (!transform.any_op) -> !transform.any_op
%9 = transform.disc.forall_to_gpu_ctas %8 : (!transform.any_op) -> !transform.any_op
%10 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.warp<x>, #gpu.warp<y>]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.forall_to_gpu_warps %10 : (!transform.any_op) -> ()
transform.disc.apply_dce %5 : !transform.any_op
transform.disc.apply_cse %5 : !transform.any_op
%11 = transform.structured.match ops{["linalg.generic"]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.gmem_to_smem %11 : (!transform.any_op) -> ()
%12 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.vector.vector_to_mma_conversion %12 : (!transform.any_op) -> ()
transform.disc.apply_licm %5 : !transform.any_op
transform.disc.apply_dce %5 : !transform.any_op
transform.disc.apply_cse %5 : !transform.any_op
%13 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op
transform.disc.inline_and_convert_gpu_ids %13 : (!transform.any_op) -> ()
transform.disc.apply_licm %5 : !transform.any_op
transform.disc.apply_dce %5 : !transform.any_op
transform.disc.apply_cse %5 : !transform.any_op
}
@@ -270,4 +270,22 @@ TEST(PackedMatmul, F32_768x3072_Using_Default_Schedule) {
/*profiling*/ true));
}

TEST(Matmul, F16_256x256x128_Using_Default_Schedule) {
EnvSetting setting = {{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
/*mlir_file_path*/ c_ft_path +
"default_schedule_matmul_nn_s_256x256x128_f16.mlir",
/*backend_types*/ {BackendType::kCuda},
/*num_inputs*/ 2,
/*num_outputs*/ 1,
/*input_descriptors*/ {"256x128xf16_X", "128x256xf16_X"},
/*output_descriptors*/ {"f16_X"},
/*input_vals*/ {},
/*expected_output_vals*/ {},
/*profiling*/ true));
}

} // namespace mlir_test
53 changes: 38 additions & 15 deletions tao_compiler/mlir/disc/tests/disc-transform/matmul.cc
@@ -33,7 +33,7 @@ static bool init_threads = []() {
TEST(SimpleTest, MatMulF32_11x13x12) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}},
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
@@ -48,7 +48,7 @@ TEST(SimpleTest, MatMulF32_11x13x12) {
TEST(SimpleTest, MatMulF32_111x131x121) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}},
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_schedule.mlir", false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
@@ -63,7 +63,8 @@ TEST(SimpleTest, MatMulF32_111x131x121) {
TEST(SimpleTest, MatMulF32_304x1024x256) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}},
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
@@ -83,7 +84,8 @@ TEST(SimpleTest, MatMulF32_304x1024x256) {
TEST(SimpleTest, MatMulF32_1024x1024x1024) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir", false}},
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
@@ -103,7 +105,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024) {
TEST(SimpleTest, MatMulF32_304x1024x256_2) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -124,7 +126,7 @@ TEST(SimpleTest, MatMulF32_304x1024x256_2) {
TEST(SimpleTest, MatMulF32_1024x1024x1024_2) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_2.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -145,7 +147,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_2) {
TEST(SimpleTest, MatMulF32_304x256x256_3) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -166,7 +168,7 @@ TEST(SimpleTest, MatMulF32_304x256x256_3) {
TEST(SimpleTest, MatMulF32_304x512x256_3) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -187,7 +189,7 @@ TEST(SimpleTest, MatMulF32_304x512x256_3) {
TEST(SimpleTest, MatMulF32_304x1024x256_3) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -208,7 +210,7 @@ TEST(SimpleTest, MatMulF32_304x1024x256_3) {
TEST(SimpleTest, MatMulF32_304x1024x512_3) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -229,7 +231,7 @@ TEST(SimpleTest, MatMulF32_304x1024x512_3) {
TEST(SimpleTest, MatMulF32_1024x1024x1024_3) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_3.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -250,7 +252,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_3) {
TEST(SimpleTest, MatMulF32_304x1024x512_4) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -271,7 +273,7 @@ TEST(SimpleTest, MatMulF32_304x1024x512_4) {
TEST(SimpleTest, MatMulF32_1024x1024x1024_4) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -292,7 +294,7 @@ TEST(SimpleTest, MatMulF32_1024x1024x1024_4) {
TEST(SimpleTest, MatMulF32_1026x1024x1024_4) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_4.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -313,7 +315,7 @@ TEST(SimpleTest, MatMulF32_1026x1024x1024_4) {
TEST(SimpleTest, MatMulF32_304x1024x512_5) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::" + c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir",
{"kGEMM::CPU:" + c_ft_path + "matmul_nn_d_f32_large_schedule_5.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"ENABLE_AARCH64_SCHEDUELS", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
@@ -331,4 +333,25 @@ TEST(SimpleTest, MatMulF32_304x1024x512_5) {
/*profiling*/ true));
}

TEST(SimpleTest, MatMulF16_GPU_256x256x128) {
EnvSetting setting = {
{"DISC_TRANSFORM_SCHEDULE_FILE",
{"kGEMM::GPU:" + c_ft_path + "matmul_nn_s_f16_gpu_schedule.mlir",
false}},
{"DISC_ENABLE_TRANSFORM_SCHEDULE", {"1", false}},
{"DISC_ENABLE_SHAPE_CONSTRAINT_IR", {"1", false}},
{"DISC_MEM_INTENSIVE_OPT_EXPERIMENTAL", {"0", false}}};
EnvSettingContext ctx(setting);
EXPECT_TRUE(feature_test_main(
/*mlir_file_path*/ c_ft_path + "matmul_nn_s_256x256x128_f16.mlir",
/*backend_types*/ {BackendType::kCuda},
/*num_inputs*/ 2,
/*num_outputs*/ 1,
/*input_descriptors*/ {"256x128xf16_X", "128x256xf16_X"},
/*output_descriptors*/ {"f16_X"},
/*input_vals*/ {},
/*expected_output_vals*/ {},
/*profiling*/ true));
}

} // namespace mlir_test
