// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: IMEX_USE_IGC_VECTOR_BACK_END=1 %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
// RUN:   --runner imex-cpu-runner -e main \
// RUN:   --entry-point-result=void \
// RUN:   --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

// NOTES :
// This example launches one workgroup containing 32 subgroups; the kernel specifies
// the computation done by a single subgroup.

module @gemm attributes {gpu.container_module} {
  // a test case that returns the transpose of A, which is viewed as memref<32x32xf16>.
  // it uses one workgroup containing 32 subgroups, organized as (8x4), so each subgroup
  // works on a 4x8 tile of A. It uses SLM to do the transpose, to exercise the
  // functionality of the SLM operations.
  func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %c4 = arith.constant 4 : index
    %c8 = arith.constant 8 : index
    %A_gpu = gpu.alloc host_shared () : memref<32x32xf16>
    memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16>
    %B_gpu = gpu.alloc host_shared () : memref<32x32xf16>
    gpu.launch_func @test_kernel::@trans_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c8, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>)
    gpu.dealloc %A_gpu : memref<32x32xf16>
    return %B_gpu : memref<32x32xf16>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @trans_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c0 = arith.constant 0 : index
      %c2 = arith.constant 2 : index
      %c3 = arith.constant 3 : index
      %c4 = arith.constant 4 : index
      %c8 = arith.constant 8 : index
      %c128 = arith.constant 128 : index
      %c256 = arith.constant 256 : index

      %sgid = gpu.subgroup_id : index
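      // decompose the linear subgroup id into a position in the (8x4) subgroup grid:
      // shrui by 2 is division by 4 and andi with 3 is modulo 4 (4 is a power of two),
      // e.g. sgid = 5 gives tid_y = 1, tid_x = 1.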
      // %tid_y = arith.divui %sgid, %c4 : index
      // %tid_x = arith.remui %sgid, %c4 : index
      %tid_y = arith.shrui %sgid, %c2 : index
      %tid_x = arith.andi %sgid, %c3 : index

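      // each subgroup owns a 4x8 tile of A whose top-left corner is at
      // row tid_y * 4, column tid_x * 8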
      %off_y = arith.muli %tid_y, %c4 : index
      %off_x = arith.muli %tid_x, %c8 : index

      // load data from global memory using block load
      %a_tile = xetile.init_tile %A[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16>
      %data = xetile.load_tile %a_tile : !xetile.tile<4x8xf16> -> vector<4x8xf16>

      %slm = memref.alloc() : memref<32x32xf16, 3>
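      // view the 32x32 SLM buffer as a flat 1024-element array so that the scattered
      // (per-element linear index) stores and loads below can address it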
      %cast = memref.reinterpret_cast %slm to offset: [0], sizes: [1024], strides: [1] : memref<32x32xf16, 3> to memref<1024xf16, 3>
      %mask = arith.constant dense<true>: vector<4x8xi1>

      // store data to slm using original layout
      %base_indices = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7],
                                            [32, 33, 34, 35, 36, 37, 38, 39],
                                            [64, 65, 66, 67, 68, 69, 70, 71],
                                            [96, 97, 98, 99, 100, 101, 102, 103]]>: vector<4x8xindex>
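      // element (i, j) of the tile is stored to linear SLM index (off_y + i) * 32 + off_x + j;
      // base_indices holds i * 32 + j and the splatted offset adds tid_y * 128 + tid_x * 8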
      %off_y2 = arith.muli %tid_y, %c128 : index
      %offset = arith.addi %off_y2, %off_x : index
      %offsets = vector.splat %offset: vector<4x8xindex>
      %indices = arith.addi %base_indices, %offsets : vector<4x8xindex>
      %st_tile = xetile.init_tile %cast, %indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>
      xetile.store %data, %st_tile, %mask : vector<4x8xf16>, !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>, vector<4x8xi1>

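      // all subgroups must finish writing their tiles to SLM before any subgroup
      // reads back the data written by the others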
      gpu.barrier

      // load data from slm using indices with transpose effects
      %trans_base_indices = arith.constant dense<[[0, 32, 64, 96, 128, 160, 192, 224],
                                                  [1, 33, 65, 97, 129, 161, 193, 225],
                                                  [2, 34, 66, 98, 130, 162, 194, 226],
                                                  [3, 35, 67, 99, 131, 163, 195, 227]]>: vector<4x8xindex>

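      // for the transposed read, element (i, j) of the output tile comes from SLM index
      // (off_x + j) * 32 + off_y + i; trans_base_indices holds j * 32 + i and the splatted
      // offset adds tid_x * 256 + tid_y * 4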
      %trans_off_x = arith.muli %tid_x, %c256 : index
      %trans_off_y = arith.muli %tid_y, %c4 : index
      %trans_off = arith.addi %trans_off_x, %trans_off_y : index
      %trans_offsets = vector.splat %trans_off: vector<4x8xindex>
      %trans_indices = arith.addi %trans_base_indices, %trans_offsets : vector<4x8xindex>
      %ld_tile = xetile.init_tile %cast, %trans_indices : memref<1024xf16, 3>, vector<4x8xindex> -> !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>
      %d = xetile.load %ld_tile, %mask : !xetile.tile<4x8xf16, #xetile.tile_attr<scattered = true, memory_space=3>>, vector<4x8xi1> -> vector<4x8xf16>

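      // write the transposed 4x8 tile back to B at the subgroup's own (off_y, off_x)
      // block, so B ends up holding the transpose of A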
      %b_tile = xetile.init_tile %B[%off_y, %off_x] : memref<32x32xf16> -> !xetile.tile<4x8xf16>
      xetile.store_tile %d, %b_tile: vector<4x8xf16>, !xetile.tile<4x8xf16>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c32 = arith.constant 32 : index
    %cf_0 = arith.constant 0.0 : bf16
    %cf_1 = arith.constant 1.0 : bf16
    %A = memref.alloc() : memref<32x32xf16>
    %Ref = memref.alloc() : memref<32x32xf32>
    // initialize matrix A with A[i][j] = i*32 + j, and Ref with its transpose
    scf.for %i = %c0 to %c32 step %c1 {
      scf.for %j = %c0 to %c32 step %c1 {
        %m = arith.muli %i, %c32 : index
        %a = arith.addi %m, %j : index
        %v = index.castu %a : index to i16
        %val = arith.uitofp %v : i16 to f16
        memref.store %val, %A[%i, %j] : memref<32x32xf16>
        %v32 = index.castu %a : index to i32
        %val32 = arith.uitofp %v32 : i32 to f32
        memref.store %val32, %Ref[%j, %i] : memref<32x32xf32>
      }
    }
    %B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf16>
    %cast = memref.cast %B : memref<32x32xf16> to memref<*xf16>
    %Ref_cast = memref.cast %Ref : memref<32x32xf32> to memref<*xf32>
    //CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF16(%cast, %Ref_cast) : (memref<*xf16>, memref<*xf32>) -> ()
    memref.dealloc %A : memref<32x32xf16>
    memref.dealloc %Ref : memref<32x32xf32>
    return
  }
  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
}