From f58d7ce62e0891b654eb3f456672c96a675b6d5f Mon Sep 17 00:00:00 2001
From: "Gusthinna Waduge, Charitha Saumya"
Date: Sat, 18 Nov 2023 00:17:55 +0000
Subject: [PATCH] XeTile GEMM e2e test cases demonstrating its capabilities

---
 include/imex/Dialect/XeTile/IR/XeTileAttrs.td |  19 +-
 .../sg_gemm_1kx1kx1k_bf16_bf16_bf32.mlir      | 148 +++++++++
 .../XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir  | 148 +++++++++
 .../XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir    | 143 +++++++++
 .../sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir     | 282 ++++++++++++++++++
 .../XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir  | 172 +++++++++++
 .../XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir    | 172 +++++++++++
 .../XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir  | 171 +++++++++++
 .../XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir    | 170 +++++++++++
 9 files changed, 1423 insertions(+), 2 deletions(-)
 create mode 100644 test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_bf32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir
 create mode 100644 test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir

diff --git a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td
index 3bb53217f..a23fdbef5 100644
--- a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td
+++ b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td
@@ -68,12 +68,12 @@ def XeTile_WorkGroupMapAttr : XeTile_Attr<"WorkGroupMap", "wg_map"> {
 def XeTile_XeMapAttr : XeTile_Attr<"XeMap", "xe_map"> {
   let parameters = (ins
-    XeTile_WorkGroupMapAttr:$wg,
+    OptionalParameter<"imex::xetile::WorkGroupMapAttr">:$wg,
     XeTile_SubGroupMapAttr:$sg
   );
   let assemblyFormat = "`<` struct(params) `>`";
   let builders = [
-    AttrBuilder<(ins "llvm::ArrayRef":$mma_block_size,
+    AttrBuilder<(ins "llvm::ArrayRef":$mma_block_size,
       "llvm::ArrayRef":$wi_layout,
       "llvm::ArrayRef":$wi_data,
       "llvm::ArrayRef":$sg_layout,
@@ -82,6 +82,7 @@ def XeTile_XeMapAttr : XeTile_Attr<"XeMap", "xe_map"> {
         return $_get($_ctxt, WorkGroupMapAttr::get($_ctxt, sg_layout, sg_data),
           SubGroupMapAttr::get($_ctxt, mma_block_size, wi_layout, wi_data)) ;
       }]>,
+    // building XeMap without mma block size
     AttrBuilder<(ins "llvm::ArrayRef":$wi_layout,
       "llvm::ArrayRef":$wi_data,
       "llvm::ArrayRef":$sg_layout,
@@ -89,7 +90,21 @@ def XeTile_XeMapAttr : XeTile_Attr<"XeMap", "xe_map"> {
       [{
         return $_get($_ctxt, WorkGroupMapAttr::get($_ctxt, sg_layout, sg_data),
           SubGroupMapAttr::get($_ctxt, wi_layout, wi_data)) ;
+      }]>,
+    // building XeMap without sub group map
+    AttrBuilder<(ins "llvm::ArrayRef":$mma_block_size,
+      "llvm::ArrayRef":$wi_layout,
+      "llvm::ArrayRef":$wi_data),
+      [{
+        return $_get($_ctxt, WorkGroupMapAttr(), SubGroupMapAttr::get($_ctxt, mma_block_size, wi_layout, wi_data));
+      }]>,
+    // building XeMap without sub group map and mma block size
+    AttrBuilder<(ins "llvm::ArrayRef":$wi_layout,
+      "llvm::ArrayRef":$wi_data),
+      [{
+        return $_get($_ctxt, WorkGroupMapAttr(), SubGroupMapAttr::get($_ctxt, wi_layout, wi_data)) ;
+      }]>
+
   ];
 }

diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_bf32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_bf32.mlir
new
file mode 100644 index 000000000..dab246e38 --- /dev/null +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_bf16_bf16_bf32.mlir @@ -0,0 +1,148 @@ +// TODO: Add imex-runner commands +// RUN: + +// NOTES : +// This example assumes one subgroup per one workgroup and the kernel specifies the computation +// done by a single subgroup. + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> + memref.copy %A, %A_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> + %B_gpu = gpu.alloc host_shared () : memref<1024x1024xbf16> + memref.copy %B, %B_gpu : memref<1024x1024xbf16> to memref<1024x1024xbf16> + %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> + memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xbf16>, %B_gpu : memref<1024x1024xbf16>, %C_gpu : memref<1024x1024xf32>) + gpu.dealloc %A_gpu : memref<1024x1024xbf16> + gpu.dealloc %B_gpu : memref<1024x1024xbf16> + return %C_gpu : memref<1024x1024xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<1024x1024xbf16>, %B: memref<1024x1024xbf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + // intialize C tile and load it + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xbf16> -> !xetile.tile<16x32xbf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xbf16> -> !xetile.tile<32x32xbf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xbf16> -> vector<16x32xbf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xbf16> -> vector<32x32xbf16> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<16x32xbf16>, vector<32x32xbf16>, vector<16x32xf32> -> vector<16x32xf32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] + : !xetile.tile<16x32xbf16>, index, index -> 
!xetile.tile<16x32xbf16> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] + : !xetile.tile<32x32xbf16>, index, index -> !xetile.tile<32x32xbf16> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<16x32xbf16>, !xetile.tile<32x32xbf16>, vector<16x32xf32> + } + // store the final accumulated C tile result back to memory + xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cf_0 = arith.constant 0.0 : bf16 + %cf_1 = arith.constant 1.0 : bf16 + %A = memref.alloc() : memref<1024x1024xbf16> + %B = memref.alloc() : memref<1024x1024xbf16> + %C = memref.alloc() : memref<1024x1024xf32> + %C_ref = memref.alloc() : memref<1024x1024xf32> + // intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to bf16 + memref.store %val, %A[%i, %j] : memref<1024x1024xbf16> + } + } + // make matrix B an identity matrix + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<1024x1024xbf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<1024x1024xbf16> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> + %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { + %a_val = memref.load %A[%i, %k] : memref<1024x1024xbf16> + %b_val = memref.load %B[%k, %j] : memref<1024x1024xbf16> + %t = arith.mulf %a_val, %b_val : bf16 + %t_cast = arith.extf %t : bf16 to f32 + %c_sum = arith.addf %t_cast, %c_partial : f32 + scf.yield %c_sum : f32 + } + memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + %2 = call @test(%A, %B, %C) : (memref<1024x1024xbf16>, memref<1024x1024xbf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + // %cast = memref.cast %B : memref<1024x1024xbf16> to memref<*xbf16> + // call @printMemrefbf16(%cast) : (memref<*xbf16>) -> () + %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> + // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> + // %C_row_0_cast = memref.cast %C_row_0 : memref<1x1024xf32> to memref<*xf32> + // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<1024x1024xbf16> + memref.dealloc %B : memref<1024x1024xbf16> + memref.dealloc %C : memref<1024x1024xf32> + memref.dealloc %C_ref : memref<1024x1024xf32> + 
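// note: %2 aliases the gpu.alloc host_shared result buffer returned by @test;
+    // it is checked by printAllcloseF32 above and is not gpu.dealloc'd here
+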
return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir new file mode 100644 index 000000000..713fb6d73 --- /dev/null +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_f16_f16_f32.mlir @@ -0,0 +1,148 @@ +// TODO: Add imex-runner commands +// RUN: + +// NOTES : +// This example assumes one subgroup per one workgroup and the kernel specifies the computation +// done by a single subgroup. + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> + %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> + %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> + memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) + gpu.dealloc %A_gpu : memref<1024x1024xf16> + gpu.dealloc %B_gpu : memref<1024x1024xf16> + return %C_gpu : memref<1024x1024xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16 : index + %n = arith.muli %block_id_y, %c32 : index + // intialize C tile and load it + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<16x32xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xf32> -> vector<16x32xf32> + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<16x32xf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c1024 step %c32 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x32xf16> -> vector<16x32xf16> + 
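// note: A's k-extent (32 columns) matches B's 32 rows, so each tile_mma
+      // below consumes one full 32-wide k-step of the surrounding loop
+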
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c32] + : !xetile.tile<16x32xf16>, index, index -> !xetile.tile<16x32xf16> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c32, %c0] + : !xetile.tile<32x32xf16>, index, index -> !xetile.tile<32x32xf16> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<16x32xf16>, !xetile.tile<32x32xf16>, vector<16x32xf32> + } + // store the final accumulated C tile result back to memory + xetile.store_tile %out#2, %c_init_tile: vector<16x32xf32>, !xetile.tile<16x32xf32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<1024x1024xf16> + %B = memref.alloc() : memref<1024x1024xf16> + %C = memref.alloc() : memref<1024x1024xf32> + %C_ref = memref.alloc() : memref<1024x1024xf32> + // intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + } + } + // make matrix B an identity matrix + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> + %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { + %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> + %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> + %t = arith.mulf %a_val, %b_val : f16 + %t_cast = arith.extf %t : f16 to f32 + %c_sum = arith.addf %t_cast, %c_partial : f32 + scf.yield %c_sum : f32 + } + memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + // %cast = memref.cast %B : memref<1024x1024xf16> to memref<*xf16> + // call @printMemrefF16(%cast) : (memref<*xf16>) -> () + %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> + // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> () + // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> () + // %C_row_0 = memref.subview %2[0, 0][1, 1024][1, 1] : memref<1024x1024xf32> to memref<1x1024xf32> + // %C_row_0_cast = memref.cast %C_row_0 : 
memref<1x1024xf32> to memref<*xf32> + // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> () + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<1024x1024xf16> + memref.dealloc %B : memref<1024x1024xf16> + memref.dealloc %C : memref<1024x1024xf32> + memref.dealloc %C_ref : memref<1024x1024xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir new file mode 100644 index 000000000..559037962 --- /dev/null +++ b/test/Integration/Dialect/XeTile/sg_gemm_1kx1kx1k_i8_i8_i32.mlir @@ -0,0 +1,143 @@ +// TODO: Add imex-runner commands +// RUN: + +// NOTES : +// This example assumes one subgroup per one workgroup and the kernel specifies the computation +// done by a single subgroup. + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> + memref.copy %A, %A_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> + %B_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> + memref.copy %B, %B_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> + %C_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> + memref.copy %C, %C_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%A_gpu : memref<1024x1024xi8>, %B_gpu : memref<1024x1024xi8>, %C_gpu : memref<1024x1024xi32>) + gpu.dealloc %A_gpu : memref<1024x1024xi8> + gpu.dealloc %B_gpu : memref<1024x1024xi8> + return %C_gpu : memref<1024x1024xi32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c16: index + %n = arith.muli %block_id_y, %c32: index + // intialize C tile and load it + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<16x32xi32> -> vector<16x32xi32> + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> -> !xetile.tile<16x64xi8> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> 
-> !xetile.tile<64x32xi8> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c1024 step %c64 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<16x64xi8> -> vector<16x64xi8> + %b_value = xetile.load_tile %b_tile : !xetile.tile<64x32xi8> -> vector<64x32xi8> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<16x64xi8>, vector<64x32xi8>, vector<16x32xi32> -> vector<16x32xi32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] + : !xetile.tile<16x64xi8>, index, index -> !xetile.tile<16x64xi8> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] + : !xetile.tile<64x32xi8>, index, index -> !xetile.tile<64x32xi8> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<16x64xi8>, !xetile.tile<64x32xi8>, vector<16x32xi32> + } + // store the final accumulated C tile result back to memory + xetile.store_tile %out#2, %c_init_tile : vector<16x32xi32>, !xetile.tile<16x32xi32> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %ci_0 = arith.constant 0 : i8 + %ci_1 = arith.constant 1 : i8 + %A = memref.alloc() : memref<1024x1024xi8> + %B = memref.alloc() : memref<1024x1024xi8> + %C = memref.alloc() : memref<1024x1024xi32> + %C_ref = memref.alloc() : memref<1024x1024xi32> + // intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %val = index.castu %j : index to i8 + memref.store %val, %A[%i, %j] : memref<1024x1024xi8> + } + } + // make matrix B an identity matrix + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %ci_1, %B[%i, %j] : memref<1024x1024xi8> + } else { + memref.store %ci_0, %B[%i, %j] : memref<1024x1024xi8> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_i32 = arith.constant 0: i32 + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c0_i32, %C[%i, %j] : memref<1024x1024xi32> + memref.store %c0_i32, %C_ref[%i, %j] : memref<1024x1024xi32> + } + } + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xi32> + %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> i32 { + %a_val = memref.load %A[%i, %k] : memref<1024x1024xi8> + %b_val = memref.load %B[%k, %j] : memref<1024x1024xi8> + %a_val_i32 = arith.extui %a_val : i8 to i32 + %b_val_i32 = arith.extui %b_val : i8 to i32 + %t = arith.muli %a_val_i32, %b_val_i32 : i32 + %c_sum = arith.addi %t, %c_partial : i32 + scf.yield %c_sum : i32 + } + memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xi32> + } + } + %2 = call @test(%A, %B, %C) : (memref<1024x1024xi8>, memref<1024x1024xi8>, memref<1024x1024xi32>) -> memref<1024x1024xi32> + %cast_C = memref.cast %2 : memref<1024x1024xi32> to memref<*xi32> + %cast_C_ref = memref.cast %C_ref : memref<1024x1024xi32> 
to memref<*xi32> + + call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> () + memref.dealloc %A : memref<1024x1024xi8> + memref.dealloc %B : memref<1024x1024xi8> + memref.dealloc %C : memref<1024x1024xi32> + memref.dealloc %C_ref : memref<1024x1024xi32> + return + } + func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefI8(memref<*xi8>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseI32(memref<*xi32>, memref<*xi32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir new file mode 100644 index 000000000..5db722d3d --- /dev/null +++ b/test/Integration/Dialect/XeTile/sg_gemm_2x2_1kx1kx1k_f16_f16_f32.mlir @@ -0,0 +1,282 @@ +// TODO: Add run commands +// RUN: + +// NOTES: +// This example assumes 2x2 subgroups per one workgroup and the kernel specifies the computation +// done by a single subgroup. This shows the result of lowering wg_gemm_1kx1kx1k_f16_f16_f32 example +// assuming the following layout maps. +// +// #wg_map_a = #xetile.wg_map +// #xe_map_a = #xetile.xe_map +// +// #wg_map_b = #xetile.wg_map +// #xe_map_b = #xetile.xe_map +// +// #wg_map_c = #xetile.wg_map +// #xe_map_c = #xetile.xe_map + + + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> + %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16> + memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16> + %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32> + memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>) + gpu.dealloc %A_gpu : memref<1024x1024xf16> + gpu.dealloc %B_gpu : memref<1024x1024xf16> + return %C_gpu : memref<1024x1024xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c128 : index + %n = arith.muli %block_id_y, %c128 : index + + // get linear sub group id + %sg_id = gpu.subgroup_id : index + // get the x, y cordinate of this linear id assuming [2, 2] coord system + %c2 = 
arith.constant 2 : index
+      %sg_coord_x = index.floordivs %sg_id, %c2
+      %sg_coord_y = index.and %sg_id, %c1
+
+      // each subgroup in the [2, 2] subgroups needs to update four 32x32 C sub-tiles
+      // that are arranged in round-robin fashion according to SG coords
+      // | (0,0) | (0,1) | (0,0) | (0,1) |
+      // | (1,0) | (1,1) | (1,0) | (1,1) |
+      // | (0,0) | (0,1) | (0,0) | (0,1) |
+      // | (1,0) | (1,1) | (1,0) | (1,1) |
+      // first calculate the offset into the first SG sub-tile
+      %C_sg_tile_offset_x = index.mul %c32, %sg_coord_x
+      %C_sg_tile_offset_y = index.mul %c32, %sg_coord_y
+
+      // C sub tiles
+      // global offset for sub tile 1 for this SG
+      %global_offset_slice0_x = index.add %m, %C_sg_tile_offset_x
+      %global_offset_slice0_y = index.add %n, %C_sg_tile_offset_y
+      // global offset for sub tile 2 for this SG (shift 64 in x)
+      %global_offset_slice1_x = index.add %global_offset_slice0_x, %c64
+      %global_offset_slice1_y = index.add %global_offset_slice0_y, %c0
+      // global offset for sub tile 3 for this SG (shift 64 in y)
+      %global_offset_slice2_x = index.add %global_offset_slice0_x, %c0
+      %global_offset_slice2_y = index.add %global_offset_slice0_y, %c64
+      // global offset for sub tile 4 for this SG (shift 64 in x and y)
+      %global_offset_slice3_x = index.add %global_offset_slice0_x, %c64
+      %global_offset_slice3_y = index.add %global_offset_slice0_y, %c64
+
+      // initialize C sub tiles and load them
+      %c_init_subtile0 = xetile.init_tile %C[%global_offset_slice0_x, %global_offset_slice0_y] : memref<1024x1024xf32>
+        -> !xetile.tile<32x32xf32>
+      %c_init_value0 = xetile.load_tile %c_init_subtile0 : !xetile.tile<32x32xf32>
+        -> vector<32x32xf32>
+      %c_init_subtile1 = xetile.init_tile %C[%global_offset_slice1_x, %global_offset_slice1_y] : memref<1024x1024xf32>
+        -> !xetile.tile<32x32xf32>
+      %c_init_value1 = xetile.load_tile %c_init_subtile1 : !xetile.tile<32x32xf32>
+        -> vector<32x32xf32>
+      %c_init_subtile2 = xetile.init_tile %C[%global_offset_slice2_x, %global_offset_slice2_y] : memref<1024x1024xf32>
+        -> !xetile.tile<32x32xf32>
+      %c_init_value2 = xetile.load_tile %c_init_subtile2 : !xetile.tile<32x32xf32>
+        -> vector<32x32xf32>
+      %c_init_subtile3 = xetile.init_tile %C[%global_offset_slice3_x, %global_offset_slice3_y] : memref<1024x1024xf32>
+        -> !xetile.tile<32x32xf32>
+      %c_init_value3 = xetile.load_tile %c_init_subtile3 : !xetile.tile<32x32xf32>
+        -> vector<32x32xf32>
+
+      // for A, each subgroup needs to load two 32x128 subtiles. The access arrangement is as follows
+      // | (0,0), (0,1)|
+      // | (1,0), (1,1)|
+      // | (0,0), (0,1)|
+      // | (1,0), (1,1)|
+
+      // calculate the initial offset in x dim for this sg
+      %a_init_offset = index.mul %sg_coord_x, %c32
+
+      // x offsets for A subtiles
+      %a_subtile0_x = index.add %m, %a_init_offset
+      %a_subtile1_x = index.add %a_subtile0_x, %c64
+
+      // init A subtiles
+      %a_init_subtile0 = xetile.init_tile %A[%a_subtile0_x, %c0] : memref<1024x1024xf16>
+        -> !xetile.tile<32x128xf16>
+      %a_init_subtile1 = xetile.init_tile %A[%a_subtile1_x, %c0] : memref<1024x1024xf16>
+        -> !xetile.tile<32x128xf16>
+
+      // for B, each subgroup needs to load two 128x32 subtiles. The access arrangement is as follows
+      // | (0,0) | (0,1) | (0,0) | (0, 1) |
+      // | (1,0) | (1,1) | (1,0) | (1, 1) |
+
+      // calculate the initial offset along y dim for this sg
+      %b_init_offset = index.mul %sg_coord_y, %c32
+
+      // y offsets for B subtiles
+      %b_subtile0_y = index.add %n, %b_init_offset
+      %b_subtile1_y = index.add %b_subtile0_y, %c64
+
+      // init B subtiles
+      %b_init_subtile0 = xetile.init_tile %B[%c0, %b_subtile0_y] : memref<1024x1024xf16>
+        -> !xetile.tile<128x32xf16>
+      %b_init_subtile1 = xetile.init_tile %B[%c0, %b_subtile1_y] : memref<1024x1024xf16>
+        -> !xetile.tile<128x32xf16>
+
+      // compute the value of C subtiles by iterating over subtiles in k-dimension and doing dpas
+      %out:8 = scf.for %k = %c0 to %c1024 step %c128
+        iter_args(%a_subtile0 = %a_init_subtile0, %a_subtile1 = %a_init_subtile1,
+                  %b_subtile0 = %b_init_subtile0, %b_subtile1 = %b_init_subtile1,
+                  %c_value0 = %c_init_value0, %c_value1 = %c_init_value1,
+                  %c_value2 = %c_init_value2, %c_value3 = %c_init_value3)
+        -> (!xetile.tile<32x128xf16>,
+            !xetile.tile<32x128xf16>,
+            !xetile.tile<128x32xf16>,
+            !xetile.tile<128x32xf16>,
+            vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>) {
+
+        // load A subtiles
+        %a_value0 = xetile.load_tile %a_subtile0 : !xetile.tile<32x128xf16>
+          -> vector<32x128xf16>
+        %a_value1 = xetile.load_tile %a_subtile1 : !xetile.tile<32x128xf16>
+          -> vector<32x128xf16>
+
+        // load B subtiles
+        %b_value0 = xetile.load_tile %b_subtile0 : !xetile.tile<128x32xf16>
+          -> vector<128x32xf16>
+        %b_value1 = xetile.load_tile %b_subtile1 : !xetile.tile<128x32xf16>
+          -> vector<128x32xf16>
+
+        // perform 4 dpas ops and update the C subtiles
+        %c_new_value0 = xetile.tile_mma %a_value0, %b_value0, %c_value0
+          : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
+        %c_new_value1 = xetile.tile_mma %a_value0, %b_value1, %c_value1
+          : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
+        %c_new_value2 = xetile.tile_mma %a_value1, %b_value0, %c_value2
+          : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
+        %c_new_value3 = xetile.tile_mma %a_value1, %b_value1, %c_value3
+          : vector<32x128xf16>, vector<128x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
+
+        // update offsets for A subtiles
+        %a_next_subtile0 = xetile.update_tile_offset %a_subtile0, [%c0, %c128]
+          : !xetile.tile<32x128xf16>, index, index
+          -> !xetile.tile<32x128xf16>
+        %a_next_subtile1 = xetile.update_tile_offset %a_subtile1, [%c0, %c128]
+          : !xetile.tile<32x128xf16>, index, index
+          -> !xetile.tile<32x128xf16>
+        // update offsets for B subtiles
+        %b_next_subtile0 = xetile.update_tile_offset %b_subtile0, [%c128, %c0]
+          : !xetile.tile<128x32xf16>, index, index
+          -> !xetile.tile<128x32xf16>
+        %b_next_subtile1 = xetile.update_tile_offset %b_subtile1, [%c128, %c0]
+          : !xetile.tile<128x32xf16>, index, index
+          -> !xetile.tile<128x32xf16>
+
+        // yield subtiles and partial C results
+        scf.yield %a_next_subtile0, %a_next_subtile1, %b_next_subtile0, %b_next_subtile1,
+                  %c_new_value0, %c_new_value1, %c_new_value2, %c_new_value3
+          : !xetile.tile<32x128xf16>,
+            !xetile.tile<32x128xf16>,
+            !xetile.tile<128x32xf16>,
+            !xetile.tile<128x32xf16>,
+            vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>, vector<32x32xf32>
+      }
+      // store the final C subtiles into memory
+      xetile.store_tile %out#4, %c_init_subtile0 : vector<32x32xf32>,
+        !xetile.tile<32x32xf32>
+      xetile.store_tile %out#5, %c_init_subtile1 : vector<32x32xf32>,
+        !xetile.tile<32x32xf32>
+      xetile.store_tile %out#6, %c_init_subtile2 : vector<32x32xf32>,
+        !xetile.tile<32x32xf32>
+      xetile.store_tile %out#7, %c_init_subtile3 : vector<32x32xf32>,
+        !xetile.tile<32x32xf32>
+
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1024 = arith.constant 1024 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<1024x1024xf16>
+    %B = memref.alloc() : memref<1024x1024xf16>
+    %C = memref.alloc() : memref<1024x1024xf32>
+    %C_ref = memref.alloc() : memref<1024x1024xf32>
+    // initialize matrix A ; A[i, j] = j
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<1024x1024xf16>
+      }
+    }
+    // make matrix B an identity matrix
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16>
+        }
+      }
+    }
+    // initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32>
+      }
+    }
+    // compute C for reference
+    scf.for %i = %c0 to %c1024 step %c1 {
+      scf.for %j = %c0 to %c1024 step %c1 {
+        %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32>
+        %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 {
+          %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16>
+          %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16>
+          %t = arith.mulf %a_val, %b_val : f16
+          %t_cast = arith.extf %t : f16 to f32
+          %c_sum = arith.addf %t_cast, %c_partial : f32
+          scf.yield %c_sum : f32
+        }
+        memref.store %c_val, %C_ref[%i, %j] : memref<1024x1024xf32>
+      }
+    }
+    %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32>
+    %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32>
+    %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32>
+
+    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    memref.dealloc %A : memref<1024x1024xf16>
+    memref.dealloc %B : memref<1024x1024xf16>
+    memref.dealloc %C : memref<1024x1024xf32>
+    memref.dealloc %C_ref : memref<1024x1024xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
+}
diff --git a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir
new file mode 100644
index 000000000..fc7031839
--- /dev/null
+++ b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_f16_f16_f32.mlir
@@ -0,0 +1,172 @@
+// TODO: Add run commands
+// RUN:
+
+// *** Experimental ***
+// This example works at the workgroup level. It demonstrates how the user can specify the
+// mapping for both subgroups within a workgroup and work items within a single subgroup. The
+// mapping of subgroups to subtiles is specified using `wg_map`, and the mapping of work items
+// to data elements is specified using `sg_map`. This gives the user full control over exactly
+// which data elements each work item operates on. XeTile fully honors the mappings provided
+// by the user.
+//
+// Note that lowering of this code to XeGPU is not supported yet because XeTile-XeGPU lowering
+// assumes subgroup level programming at XeTile.
+
+#sg_map_a = #xetile.sg_map
+#wg_map_a = #xetile.wg_map
+#xe_map_a = #xetile.xe_map
+
+#sg_map_b = #xetile.sg_map
+#wg_map_b = #xetile.wg_map
+#xe_map_b = #xetile.xe_map
+
+#sg_map_c = #xetile.sg_map
+#wg_map_c = #xetile.wg_map
+#xe_map_c = #xetile.xe_map
+
+module @gemm attributes {gpu.container_module} {
+  func.func @test(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) -> memref<1024x1024xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c512 = arith.constant 512 : index
+    %A_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %A, %A_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %B_gpu = gpu.alloc host_shared () : memref<1024x1024xf16>
+    memref.copy %B, %B_gpu : memref<1024x1024xf16> to memref<1024x1024xf16>
+    %C_gpu = gpu.alloc host_shared () : memref<1024x1024xf32>
+    memref.copy %C, %C_gpu : memref<1024x1024xf32> to memref<1024x1024xf32>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%A_gpu : memref<1024x1024xf16>, %B_gpu : memref<1024x1024xf16>, %C_gpu : memref<1024x1024xf32>)
+    gpu.dealloc %A_gpu : memref<1024x1024xf16>
+    gpu.dealloc %B_gpu : memref<1024x1024xf16>
+    return %C_gpu : memref<1024x1024xf32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      // %c8 = arith.constant 8 : index
+      // %c16 = arith.constant 16 : index
+      %c128 = arith.constant 128 : index
+      %c1024 = arith.constant 1024 : index
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %m = arith.muli %block_id_x, %c128 : index
+      %n = arith.muli %block_id_y, %c128 : index
+      // initialize C tile and load it
+      %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32>
+        -> !xetile.tile<128x128xf32, #xe_map_c>
+      %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xf32, #xe_map_c>
+        -> vector<128x128xf32>
+      // initialize A and B tiles
+      %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16>
+        -> !xetile.tile<128x128xf16, #xe_map_a>
+      %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16>
+        -> !xetile.tile<128x128xf16, #xe_map_b>
+      // compute the value of C tile by iterating over tiles in k-dimension and doing dpas
+      %out:3 = scf.for %k = %c0 to %c1024 step %c128
+        iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
+        ->
(!xetile.tile<128x128xf16, #xe_map_a>, + !xetile.tile<128x128xf16, #xe_map_b>, + vector<128x128xf32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xf16, #xe_map_a> + -> vector<128x128xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xf16, #xe_map_b> + -> vector<128x128xf16> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] + : !xetile.tile<128x128xf16, #xe_map_a>, index, index + -> !xetile.tile<128x128xf16, #xe_map_a> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] + : !xetile.tile<128x128xf16, #xe_map_b>, index, index + -> !xetile.tile<128x128xf16, #xe_map_b> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<128x128xf16, #xe_map_a>, + !xetile.tile<128x128xf16, #xe_map_b>, vector<128x128xf32> + } + // store the final accumulated C tile result back to memory + xetile.store_tile %out#2, %c_init_tile : vector<128x128xf32>, + !xetile.tile<128x128xf32, #xe_map_c> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<1024x1024xf16> + %B = memref.alloc() : memref<1024x1024xf16> + %C = memref.alloc() : memref<1024x1024xf32> + %C_ref = memref.alloc() : memref<1024x1024xf32> + // intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %t = index.castu %j : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<1024x1024xf16> + } + } + // make matrix B an identity matrix + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %cf_1, %B[%i, %j] : memref<1024x1024xf16> + } else { + memref.store %cf_0, %B[%i, %j] : memref<1024x1024xf16> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_f32 = arith.constant 0.0 : f32 + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c0_f32, %C[%i, %j] : memref<1024x1024xf32> + memref.store %c0_f32, %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xf32> + %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> f32 { + %a_val = memref.load %A[%i, %k] : memref<1024x1024xf16> + %b_val = memref.load %B[%k, %j] : memref<1024x1024xf16> + %t = arith.mulf %a_val, %b_val : f16 + %t_cast = arith.extf %t : f16 to f32 + %c_sum = arith.addf %t_cast, %c_partial : f32 + scf.yield %c_sum : f32 + } + memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xf32> + } + } + %2 = call @test(%A, %B, %C) : (memref<1024x1024xf16>, memref<1024x1024xf16>, memref<1024x1024xf32>) -> memref<1024x1024xf32> + %cast_C = memref.cast %2 : memref<1024x1024xf32> to memref<*xf32> + %cast_C_ref = memref.cast %C_ref : memref<1024x1024xf32> to memref<*xf32> + + call @printAllcloseF32(%cast_C, 
%cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> () + memref.dealloc %A : memref<1024x1024xf16> + memref.dealloc %B : memref<1024x1024xf16> + memref.dealloc %C : memref<1024x1024xf32> + memref.dealloc %C_ref : memref<1024x1024xf32> + return + } + func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir new file mode 100644 index 000000000..3081f2572 --- /dev/null +++ b/test/Integration/Dialect/XeTile/wg_gemm_1kx1kx1k_i8_i8_i32.mlir @@ -0,0 +1,172 @@ +// TODO: Add run commands +// RUN: + +// *** Experimental *** +// This example works at the work grpup level. This demonstrates how the user can specify the +// mapping for both subgroup within workgroup and work items within a single subgroup. The mapping +// of subgroups to subtiles are specified using `wg_map` and, work items to data elements mapping is +// specified using `sg_map`. Through this way, user has full control of how each work items works on +// exactly which data elements. XeTile fully honor the mapping provided by users. +// +// Note that lowering of this code to XeGPU is not supported yet because XeTile-XeGPU lowering assumes +// subgroup level programming at XeTile. + +#sg_map_a = #xetile.sg_map +#wg_map_a = #xetile.wg_map +#xe_map_a = #xetile.xe_map + +#sg_map_b = #xetile.sg_map +#wg_map_b = #xetile.wg_map +#xe_map_b = #xetile.xe_map + +#sg_map_c = #xetile.sg_map +#wg_map_c = #xetile.wg_map +#xe_map_c = #xetile.xe_map + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> + memref.copy %A, %A_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> + %B_gpu = gpu.alloc host_shared () : memref<1024x1024xi8> + memref.copy %B, %B_gpu : memref<1024x1024xi8> to memref<1024x1024xi8> + %C_gpu = gpu.alloc host_shared () : memref<1024x1024xi32> + memref.copy %C, %C_gpu : memref<1024x1024xi32> to memref<1024x1024xi32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c8, %c8, %c1) threads in (%c2, %c2, %c1) args(%A_gpu : memref<1024x1024xi8>, %B_gpu : memref<1024x1024xi8>, %C_gpu : memref<1024x1024xi32>) + gpu.dealloc %A_gpu : memref<1024x1024xi8> + gpu.dealloc %B_gpu : memref<1024x1024xi8> + return %C_gpu : memref<1024x1024xi32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 
: index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c128 : index + %n = arith.muli %block_id_y, %c128 : index + // intialize C tile and load it + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> + -> !xetile.tile<128x128xi32, #xe_map_c> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xi32, #xe_map_c> + -> vector<128x128xi32> + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> + -> !xetile.tile<128x128xi8, #xe_map_a> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> + -> !xetile.tile<128x128xi8, #xe_map_b> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c1024 step %c128 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<128x128xi8, #xe_map_a>, + !xetile.tile<128x128xi8, #xe_map_b>, + vector<128x128xi32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xi8, #xe_map_a> + -> vector<128x128xi8> + %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xi8, #xe_map_b> + -> vector<128x128xi8> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<128x128xi8>, vector<128x128xi8>, vector<128x128xi32> -> vector<128x128xi32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] + : !xetile.tile<128x128xi8, #xe_map_a>, index, index + -> !xetile.tile<128x128xi8, #xe_map_a> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] + : !xetile.tile<128x128xi8, #xe_map_b>, index, index + -> !xetile.tile<128x128xi8, #xe_map_b> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<128x128xi8, #xe_map_a>, + !xetile.tile<128x128xi8, #xe_map_b>, vector<128x128xi32> + } + // store the final accumulated C tile result back to memory + xetile.store_tile %out#2, %c_init_tile : vector<128x128xi32>, + !xetile.tile<128x128xi32, #xe_map_c> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %ci_0 = arith.constant 0 : i8 + %ci_1 = arith.constant 1 : i8 + %A = memref.alloc() : memref<1024x1024xi8> + %B = memref.alloc() : memref<1024x1024xi8> + %C = memref.alloc() : memref<1024x1024xi32> + %C_ref = memref.alloc() : memref<1024x1024xi32> + // intialize matrix A ; A[i, j] = j + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %val = index.castu %j : index to i8 + memref.store %val, %A[%i, %j] : memref<1024x1024xi8> + } + } + // make matrix B an identity matrix + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + %i_i32 = index.castu %i : index to i32 + %j_i32 = index.castu %j : index to i32 + %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32 + + scf.if %i_j_same { + memref.store %ci_1, %B[%i, %j] : memref<1024x1024xi8> + } else { + memref.store %ci_0, %B[%i, %j] : memref<1024x1024xi8> + } + } + } + // intialize matrix C and C_ref ; C[i, j] = 0 + %c0_i32 = arith.constant 0: i32 + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 to %c1024 step %c1 { + memref.store %c0_i32, %C[%i, %j] : memref<1024x1024xi32> + memref.store %c0_i32, %C_ref[%i, %j] : memref<1024x1024xi32> + } + } + // compute C for reference + scf.for %i = %c0 to %c1024 step %c1 { + scf.for %j = %c0 
to %c1024 step %c1 { + %c_curr = memref.load %C_ref[%i, %j] : memref<1024x1024xi32> + %c_val = scf.for %k = %c0 to %c1024 step %c1 iter_args(%c_partial = %c_curr) -> i32 { + %a_val = memref.load %A[%i, %k] : memref<1024x1024xi8> + %b_val = memref.load %B[%k, %j] : memref<1024x1024xi8> + %a_val_i32 = arith.extui %a_val : i8 to i32 + %b_val_i32 = arith.extui %b_val : i8 to i32 + %t = arith.muli %a_val_i32, %b_val_i32 : i32 + %c_sum = arith.addi %t, %c_partial : i32 + scf.yield %c_sum : i32 + } + memref.store %c_val , %C_ref[%i, %j] : memref<1024x1024xi32> + } + } + %2 = call @test(%A, %B, %C) : (memref<1024x1024xi8>, memref<1024x1024xi8>, memref<1024x1024xi32>) -> memref<1024x1024xi32> + %cast_C = memref.cast %2 : memref<1024x1024xi32> to memref<*xi32> + %cast_C_ref = memref.cast %C_ref : memref<1024x1024xi32> to memref<*xi32> + + call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> () + memref.dealloc %A : memref<1024x1024xi8> + memref.dealloc %B : memref<1024x1024xi8> + memref.dealloc %C : memref<1024x1024xi32> + memref.dealloc %C_ref : memref<1024x1024xi32> + return + } + func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface} + func.func private @printMemrefI8(memref<*xi8>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseI32(memref<*xi32>, memref<*xi32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir new file mode 100644 index 000000000..1564e9685 --- /dev/null +++ b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_f16_f16_f32.mlir @@ -0,0 +1,171 @@ +// TODO: Add run commands +// RUN: + +// *** Experimental *** +// This example works at the work grpup level. This demonstrates how the user can specify the +// mapping for both subgroup within workgroup and work items within a single subgroup. The mapping +// of subgroups to subtiles are specified using `wg_map` and, work items to data elements mapping is +// specified using `sg_map`. Through this way, user has full control of how each work items works on +// exactly which data elements. XeTile fully honor the mapping provided by users. +// +// Note that lowering of this code to XeGPU is not supported yet because XeTile-XeGPU lowering assumes +// subgroup level programming at XeTile. 
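+//
+// Compared to the 1Kx1Kx1K workgroup-level test, this variant only scales the
+// shapes: each workgroup owns a 256x256 C tile, a 16x16 grid of workgroups
+// (blocks in (16, 16, 1)) covers the 4096x4096 output, and the k-loop advances
+// in steps of 256.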
+ + +#sg_map_a = #xetile.sg_map +#wg_map_a = #xetile.wg_map +#xe_map_a = #xetile.xe_map + +#sg_map_b = #xetile.sg_map +#wg_map_b = #xetile.wg_map +#xe_map_b = #xetile.xe_map + +#sg_map_c = #xetile.sg_map +#wg_map_c = #xetile.wg_map +#xe_map_c = #xetile.xe_map + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) -> memref<4096x4096xf32> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %A_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> + memref.copy %A, %A_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> + %B_gpu = gpu.alloc host_shared () : memref<4096x4096xf16> + memref.copy %B, %B_gpu : memref<4096x4096xf16> to memref<4096x4096xf16> + %C_gpu = gpu.alloc host_shared () : memref<4096x4096xf32> + memref.copy %C, %C_gpu : memref<4096x4096xf32> to memref<4096x4096xf32> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<4096x4096xf16>, %B_gpu : memref<4096x4096xf16>, %C_gpu : memref<4096x4096xf32>) + gpu.dealloc %A_gpu : memref<4096x4096xf16> + gpu.dealloc %B_gpu : memref<4096x4096xf16> + return %C_gpu : memref<4096x4096xf32> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c256 : index + %n = arith.muli %block_id_y, %c256 : index + // intialize C tile and load it + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> + -> !xetile.tile<256x256xf32, #xe_map_c> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xf32, #xe_map_c> + -> vector<256x256xf32> + // initalize A and B tiles + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xf16> + -> !xetile.tile<256x256xf16, #xe_map_a> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xf16> + -> !xetile.tile<256x256xf16, #xe_map_b> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c4096 step %c256 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<256x256xf16, #xe_map_a>, + !xetile.tile<256x256xf16, #xe_map_b>, + vector<256x256xf32>) { + + // load A and B tiles + %a_value = xetile.load_tile %a_tile : !xetile.tile<256x256xf16, #xe_map_a> + -> vector<256x256xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<256x256xf16, #xe_map_b> + -> vector<256x256xf16> + // perform dpas and accumulate + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<256x256xf16>, vector<256x256xf16>, vector<256x256xf32> -> vector<256x256xf32> + // update the offsets for A and B tiles + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256] + : !xetile.tile<256x256xf16, #xe_map_a>, 
+        %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256]
+          : !xetile.tile<256x256xf16, #xe_map_a>, index, index
+          -> !xetile.tile<256x256xf16, #xe_map_a>
+        %b_next_tile = xetile.update_tile_offset %b_tile, [%c256, %c0]
+          : !xetile.tile<256x256xf16, #xe_map_b>, index, index
+          -> !xetile.tile<256x256xf16, #xe_map_b>
+        // yield the partial C tile result
+        scf.yield %a_next_tile, %b_next_tile, %c_new_value
+          : !xetile.tile<256x256xf16, #xe_map_a>,
+            !xetile.tile<256x256xf16, #xe_map_b>, vector<256x256xf32>
+      }
+      // store the final accumulated C tile result back to memory
+      xetile.store_tile %out#2, %c_init_tile : vector<256x256xf32>,
+        !xetile.tile<256x256xf32, #xe_map_c>
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4096 = arith.constant 4096 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<4096x4096xf16>
+    %B = memref.alloc() : memref<4096x4096xf16>
+    %C = memref.alloc() : memref<4096x4096xf32>
+    %C_ref = memref.alloc() : memref<4096x4096xf32>
+    // initialize matrix A ; A[i, j] = j (rounded to f16 precision)
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %t = index.castu %j : index to i16
+        %val = arith.uitofp %t : i16 to f16
+        memref.store %val, %A[%i, %j] : memref<4096x4096xf16>
+      }
+    }
+    // make matrix B an identity matrix
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %cf_1, %B[%i, %j] : memref<4096x4096xf16>
+        } else {
+          memref.store %cf_0, %B[%i, %j] : memref<4096x4096xf16>
+        }
+      }
+    }
+    // initialize matrices C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<4096x4096xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<4096x4096xf32>
+      }
+    }
+    // compute C for reference
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %c_curr = memref.load %C_ref[%i, %j] : memref<4096x4096xf32>
+        %c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args(%c_partial = %c_curr) -> f32 {
+          %a_val = memref.load %A[%i, %k] : memref<4096x4096xf16>
+          %b_val = memref.load %B[%k, %j] : memref<4096x4096xf16>
+          %t = arith.mulf %a_val, %b_val : f16
+          %t_cast = arith.extf %t : f16 to f32
+          %c_sum = arith.addf %t_cast, %c_partial : f32
+          scf.yield %c_sum : f32
+        }
+        memref.store %c_val, %C_ref[%i, %j] : memref<4096x4096xf32>
+      }
+    }
+    %2 = call @test(%A, %B, %C) : (memref<4096x4096xf16>, memref<4096x4096xf16>, memref<4096x4096xf32>) -> memref<4096x4096xf32>
+    %cast_C = memref.cast %2 : memref<4096x4096xf32> to memref<*xf32>
+    %cast_C_ref = memref.cast %C_ref : memref<4096x4096xf32> to memref<*xf32>
+
+    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    memref.dealloc %A : memref<4096x4096xf16>
+    memref.dealloc %B : memref<4096x4096xf16>
+    memref.dealloc %C : memref<4096x4096xf32>
+    memref.dealloc %C_ref : memref<4096x4096xf32>
+    return
+  }
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
+}
diff --git a/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir
new file mode 100644
index 000000000..0fd16fdf3
--- /dev/null
+++ b/test/Integration/Dialect/XeTile/wg_gemm_4kx4kx4k_i8_i8_i32.mlir
@@ -0,0 +1,170 @@
+// TODO: Add run commands
+// RUN:

+// *** Experimental ***
+// This example works at the work group level. It demonstrates how the user can specify the
+// mapping for both subgroups within a workgroup and work items within a single subgroup.
+// The mapping of subgroups to subtiles is specified using `wg_map`, and the mapping of work
+// items to data elements is specified using `sg_map`. This gives the user full control over
+// exactly which data elements each work item works on. XeTile fully honors the mapping
+// provided by the user.
+//
+// Note that lowering of this code to XeGPU is not supported yet because the XeTile-to-XeGPU
+// lowering assumes subgroup-level programming at the XeTile level.
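+//
+// As in the f16 variant, the wg_map/sg_map parameter values below are representative
+// choices: sg_layout = [4, 4] with sg_data = [64, 64] splits each 256x256 workgroup
+// tile into a 4x4 grid of 64x64 subtiles. The assumed mma block for i8 is deeper in
+// the K dimension, since dpas packs four i8 values per 32-bit channel.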
+
+#sg_map_a = #xetile.sg_map<mma_block_size = [8, 32], wi_layout = [2, 8], wi_data = [1, 4]>
+#wg_map_a = #xetile.wg_map<sg_layout = [4, 4], sg_data = [64, 64]>
+#xe_map_a = #xetile.xe_map<wg = #wg_map_a, sg = #sg_map_a>
+
+#sg_map_b = #xetile.sg_map<mma_block_size = [32, 16], wi_layout = [1, 16], wi_data = [4, 1]>
+#wg_map_b = #xetile.wg_map<sg_layout = [4, 4], sg_data = [64, 64]>
+#xe_map_b = #xetile.xe_map<wg = #wg_map_b, sg = #sg_map_b>
+
+#sg_map_c = #xetile.sg_map<mma_block_size = [8, 16], wi_layout = [1, 16], wi_data = [1, 1]>
+#wg_map_c = #xetile.wg_map<sg_layout = [4, 4], sg_data = [64, 64]>
+#xe_map_c = #xetile.xe_map<wg = #wg_map_c, sg = #sg_map_c>
+
+module @gemm attributes {gpu.container_module} {
+  func.func @test(%A: memref<4096x4096xi8>, %B: memref<4096x4096xi8>, %C: memref<4096x4096xi32>) -> memref<4096x4096xi32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c512 = arith.constant 512 : index
+    %A_gpu = gpu.alloc host_shared () : memref<4096x4096xi8>
+    memref.copy %A, %A_gpu : memref<4096x4096xi8> to memref<4096x4096xi8>
+    %B_gpu = gpu.alloc host_shared () : memref<4096x4096xi8>
+    memref.copy %B, %B_gpu : memref<4096x4096xi8> to memref<4096x4096xi8>
+    %C_gpu = gpu.alloc host_shared () : memref<4096x4096xi32>
+    memref.copy %C, %C_gpu : memref<4096x4096xi32> to memref<4096x4096xi32>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c16, %c16, %c1) threads in (%c4, %c4, %c1) args(%A_gpu : memref<4096x4096xi8>, %B_gpu : memref<4096x4096xi8>, %C_gpu : memref<4096x4096xi32>)
+    gpu.dealloc %A_gpu : memref<4096x4096xi8>
+    gpu.dealloc %B_gpu : memref<4096x4096xi8>
+    return %C_gpu : memref<4096x4096xi32>
+  }
+  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+    gpu.func @test_kernel(%A: memref<4096x4096xi8>, %B: memref<4096x4096xi8>, %C: memref<4096x4096xi32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c256 = arith.constant 256 : index
+      %c4096 = arith.constant 4096 : index
+      %block_id_x = gpu.block_id x
+      %block_id_y = gpu.block_id y
+      %m = arith.muli %block_id_x, %c256 : index
+      %n = arith.muli %block_id_y, %c256 : index
+      // initialize C tile and load it
+      %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xi32>
+        -> !xetile.tile<256x256xi32, #xe_map_c>
+      %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<256x256xi32, #xe_map_c>
+        -> vector<256x256xi32>
+      // initialize A and B tiles
+      %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xi8>
+        -> !xetile.tile<256x256xi8, #xe_map_a>
+      %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xi8>
+        -> !xetile.tile<256x256xi8, #xe_map_b>
+      // compute the value of the C tile by iterating over tiles in the k-dimension and doing dpas
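+      // The loop below carries three values across iterations: the A and B tile
+      // handles (advanced along K each step) and the accumulated C vector; the
+      // final accumulator %out#2 is stored back once after the loop.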
+      %out:3 = scf.for %k = %c0 to %c4096 step %c256
+        iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
+        -> (!xetile.tile<256x256xi8, #xe_map_a>,
+            !xetile.tile<256x256xi8, #xe_map_b>,
+            vector<256x256xi32>) {
+
+        // load A and B tiles
+        %a_value = xetile.load_tile %a_tile : !xetile.tile<256x256xi8, #xe_map_a>
+          -> vector<256x256xi8>
+        %b_value = xetile.load_tile %b_tile : !xetile.tile<256x256xi8, #xe_map_b>
+          -> vector<256x256xi8>
+        // perform dpas and accumulate
+        %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value
+          : vector<256x256xi8>, vector<256x256xi8>, vector<256x256xi32> -> vector<256x256xi32>
+        // update the offsets for A and B tiles
+        %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256]
+          : !xetile.tile<256x256xi8, #xe_map_a>, index, index
+          -> !xetile.tile<256x256xi8, #xe_map_a>
+        %b_next_tile = xetile.update_tile_offset %b_tile, [%c256, %c0]
+          : !xetile.tile<256x256xi8, #xe_map_b>, index, index
+          -> !xetile.tile<256x256xi8, #xe_map_b>
+        // yield the partial C tile result
+        scf.yield %a_next_tile, %b_next_tile, %c_new_value
+          : !xetile.tile<256x256xi8, #xe_map_a>,
+            !xetile.tile<256x256xi8, #xe_map_b>, vector<256x256xi32>
+      }
+      // store the final accumulated C tile result back to memory
+      xetile.store_tile %out#2, %c_init_tile : vector<256x256xi32>,
+        !xetile.tile<256x256xi32, #xe_map_c>
+      gpu.return
+    }
+  }
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4096 = arith.constant 4096 : index
+    %ci_0 = arith.constant 0 : i8
+    %ci_1 = arith.constant 1 : i8
+    %A = memref.alloc() : memref<4096x4096xi8>
+    %B = memref.alloc() : memref<4096x4096xi8>
+    %C = memref.alloc() : memref<4096x4096xi32>
+    %C_ref = memref.alloc() : memref<4096x4096xi32>
+    // initialize matrix A ; A[i, j] = j (truncated to i8, i.e. j mod 256)
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %val = index.castu %j : index to i8
+        memref.store %val, %A[%i, %j] : memref<4096x4096xi8>
+      }
+    }
+    // make matrix B an identity matrix
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %i_i32 = index.castu %i : index to i32
+        %j_i32 = index.castu %j : index to i32
+        %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+        scf.if %i_j_same {
+          memref.store %ci_1, %B[%i, %j] : memref<4096x4096xi8>
+        } else {
+          memref.store %ci_0, %B[%i, %j] : memref<4096x4096xi8>
+        }
+      }
+    }
+    // initialize matrices C and C_ref ; C[i, j] = 0
+    %c0_i32 = arith.constant 0 : i32
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        memref.store %c0_i32, %C[%i, %j] : memref<4096x4096xi32>
+        memref.store %c0_i32, %C_ref[%i, %j] : memref<4096x4096xi32>
+      }
+    }
+    // compute C for reference
+    scf.for %i = %c0 to %c4096 step %c1 {
+      scf.for %j = %c0 to %c4096 step %c1 {
+        %c_curr = memref.load %C_ref[%i, %j] : memref<4096x4096xi32>
+        %c_val = scf.for %k = %c0 to %c4096 step %c1 iter_args(%c_partial = %c_curr) -> i32 {
+          %a_val = memref.load %A[%i, %k] : memref<4096x4096xi8>
+          %b_val = memref.load %B[%k, %j] : memref<4096x4096xi8>
+          %a_val_i32 = arith.extui %a_val : i8 to i32
+          %b_val_i32 = arith.extui %b_val : i8 to i32
+          %t = arith.muli %a_val_i32, %b_val_i32 : i32
+          %c_sum = arith.addi %t, %c_partial : i32
+          scf.yield %c_sum : i32
+        }
+        memref.store %c_val, %C_ref[%i, %j] : memref<4096x4096xi32>
+      }
+    }
+    %2 = call @test(%A, %B, %C) : (memref<4096x4096xi8>, memref<4096x4096xi8>, memref<4096x4096xi32>) -> memref<4096x4096xi32>
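+    // Compare the GPU result against the CPU reference using the printAllcloseI32
+    // runtime helper declared below.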
+    %cast_C = memref.cast %2 : memref<4096x4096xi32> to memref<*xi32>
+    %cast_C_ref = memref.cast %C_ref : memref<4096x4096xi32> to memref<*xi32>
+
+    call @printAllcloseI32(%cast_C, %cast_C_ref) : (memref<*xi32>, memref<*xi32>) -> ()
+    memref.dealloc %A : memref<4096x4096xi8>
+    memref.dealloc %B : memref<4096x4096xi8>
+    memref.dealloc %C : memref<4096x4096xi32>
+    memref.dealloc %C_ref : memref<4096x4096xi32>
+    return
+  }
+  func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefI8(memref<*xi8>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseI32(memref<*xi32>, memref<*xi32>) attributes {llvm.emit_c_interface}
+}