Skip to content

Commit

Permalink
[XeTileToXeGPU] Add elementwise integer arith op patterns. (#952)
Browse files Browse the repository at this point in the history
* [XeTileToXeGPU] Add elementwise integer arith op patterns.

These patterns allow the elementwise ops to be blocked and updated.
  • Loading branch information
mshahneo authored Nov 5, 2024
1 parent ffb8bf0 commit c741d7d
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 6 deletions.
24 changes: 20 additions & 4 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
patterns.getContext(), converter, analysis);

// Element-wise math operations
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
ElementWiseOpPattern<mlir::math::SinOp, 1>,
Expand All @@ -1267,16 +1269,30 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
ElementWiseOpPattern<mlir::math::LogOp, 1>,
ElementWiseOpPattern<mlir::math::RsqrtOp, 1>,
ElementWiseOpPattern<mlir::math::ErfOp, 1>,
ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
ElementWiseOpPattern<mlir::arith::AndIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
ElementWiseOpPattern<mlir::math::PowFOp, 2>>(
patterns.getContext(), converter, analysis);

// Element-wise arithmetic operations
patterns.insert<ElementWiseOpPattern<mlir::arith::AddFOp, 2>,
ElementWiseOpPattern<mlir::arith::AddIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivFOp, 2>,
ElementWiseOpPattern<mlir::arith::DivSIOp, 2>,
ElementWiseOpPattern<mlir::arith::DivUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MulFOp, 2>,
ElementWiseOpPattern<mlir::arith::MulIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaximumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MaxUIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinimumFOp, 2>,
ElementWiseOpPattern<mlir::arith::MinSIOp, 2>,
ElementWiseOpPattern<mlir::arith::MinUIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemFOp, 2>,
ElementWiseOpPattern<mlir::arith::RemSIOp, 2>,
ElementWiseOpPattern<mlir::arith::RemUIOp, 2>,
ElementWiseOpPattern<mlir::arith::SubFOp, 2>,
ElementWiseOpPattern<mlir::arith::SubIOp, 2>,
ElementWiseOpPattern<mlir::arith::AndIOp, 2>,
ElementWiseOpPattern<mlir::arith::XOrIOp, 2>,
ElementWiseOpPattern<mlir::math::PowFOp, 2>,
ElementWiseOpPattern<mlir::arith::SelectOp, 3>>(
patterns.getContext(), converter, analysis);
patterns.insert<TypecastOpPattern<mlir::arith::ExtFOp>,
Expand Down
24 changes: 22 additions & 2 deletions lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,20 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
// Arith ops
addDynamicallyLegalOp<mlir::arith::AddFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::AddIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::AndIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::DivFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::DivSIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::DivUIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MulFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MulIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::CmpFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::CmpIOp>(
Expand All @@ -105,16 +113,28 @@ class XeTileConversionTarget : public mlir::ConversionTarget {
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::SubFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::SubIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MaximumFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MaxSIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MaxUIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::RemFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::NegFOp>(
addDynamicallyLegalOp<mlir::arith::RemSIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MaximumFOp>(
addDynamicallyLegalOp<mlir::arith::RemUIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::NegFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MinimumFOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MinSIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::MinUIOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::SelectOp>(
[&](mlir::Operation *op) -> bool { return isLegalElementWiseOp(op); });
addDynamicallyLegalOp<mlir::arith::ExtFOp>(
Expand Down
32 changes: 32 additions & 0 deletions test/Conversion/XeTileToXeGPU/elementwise_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,38 @@
gpu.return
}

// Verifies lowering of blocked elementwise integer arith ops (addi, subi,
// muli, max/min (signed+unsigned), div/rem (signed+unsigned), andi).
// The unpack/pack pair re-blocks the 16x16 inner tiles into 1x16 inner
// tiles, so each 64x4x1x16 elementwise op is expected to expand into
// 64*4 = 256 scalar-vector ops on vector<1x16xi16>.
gpu.func @arith_binary_ops_int() {
%0 = arith.constant dense<1>: vector<4x4x16x16xi16>
%1 = arith.constant dense<2>: vector<64x4x1x16xi16>
// Re-block %0 from 16x16 inner blocks to 1x16 inner blocks so its layout
// matches %1 before the elementwise ops below.
%2 = xetile.tile_unpack %0 {inner_blocks = array<i64: 16, 16>}: vector<4x4x16x16xi16> -> vector<64x64xi16>
%3 = xetile.tile_pack %2 {inner_blocks = array<i64: 1, 16>}: vector<64x64xi16> -> vector<64x4x1x16xi16>
// CHECK-COUNT-256: arith.addi {{.*}}, {{.*}} : vector<1x16xi16>
// CHECK-COUNT-256: arith.subi
// CHECK-COUNT-256: arith.muli
// CHECK-COUNT-256: arith.maxsi
// CHECK-COUNT-256: arith.maxui
// CHECK-COUNT-256: arith.minsi
// CHECK-COUNT-256: arith.minui
// CHECK-COUNT-256: arith.divsi
// CHECK-COUNT-256: arith.divui
// CHECK-COUNT-256: arith.remsi
// CHECK-COUNT-256: arith.remui
// CHECK-COUNT-256: arith.andi
// The ops are chained so none is trivially dead; results are otherwise
// unused — this test only inspects the lowered op counts above.
%result = arith.addi %3, %1 : vector<64x4x1x16xi16>
%subi_result = arith.subi %3, %1 : vector<64x4x1x16xi16>
%muli_result = arith.muli %subi_result, %1 : vector<64x4x1x16xi16>
%maxsi_result = arith.maxsi %muli_result, %1 : vector<64x4x1x16xi16>
%maxui_result = arith.maxui %muli_result, %1 : vector<64x4x1x16xi16>
%minsi_result = arith.minsi %maxsi_result, %muli_result : vector<64x4x1x16xi16>
%minui_result = arith.minui %maxui_result, %muli_result : vector<64x4x1x16xi16>
%divsi_result = arith.divsi %minui_result, %1 : vector<64x4x1x16xi16>
%divui_result = arith.divui %minui_result, %1 : vector<64x4x1x16xi16>
%remsi_result = arith.remsi %minsi_result, %divsi_result : vector<64x4x1x16xi16>
%remui_result = arith.remui %minui_result, %divui_result : vector<64x4x1x16xi16>
%and_result = arith.andi %remsi_result, %remui_result : vector<64x4x1x16xi16>
gpu.return
}

gpu.func @arith_xori_ops() {
%0 = arith.constant dense<1>: vector<4x4x16x16xi16>
%1 = arith.constant dense<2>: vector<64x4x1x16xi16>
Expand Down
78 changes: 78 additions & 0 deletions test/Integration/Dialect/XeTile/eltwise_int_ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xetile-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
// End-to-end test for elementwise integer arith ops on XeTile vectors.
// Inputs are A = dense<5> and B = dense<2>; the chained add/sub/mul/div/rem
// computation below yields 3 for every element of the 1024x1024 result.
module @eltwise_int attributes {gpu.container_module} {
memref.global "private" constant @__constant_5_1024x1024xi32 : memref<1024x1024xi32> = dense<5>
memref.global "private" constant @__constant_2_1024x1024xi32 : memref<1024x1024xi32> = dense<2>

// Host wrapper: copies the two host inputs into host_shared GPU buffers,
// launches the kernel over a 64x32 grid (each block handles a 16x32 tile
// of the 1024x1024 matrices), and returns the result buffer.
// The result buffer is intentionally not deallocated here — it is
// host_shared and returned to the caller for printing.
func.func @eltwise_int_test(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>) -> memref<1024x1024xi32> attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index

%arg0_gpu = gpu.alloc host_shared () : memref<1024x1024xi32>
memref.copy %arg0, %arg0_gpu : memref<1024x1024xi32> to memref<1024x1024xi32>

%arg1_gpu = gpu.alloc host_shared () : memref<1024x1024xi32>
memref.copy %arg1, %arg1_gpu : memref<1024x1024xi32> to memref<1024x1024xi32>

%result = gpu.alloc host_shared () : memref<1024x1024xi32>

gpu.launch_func @eltwise_int::@eltwise_int blocks in (%c64, %c32, %c1) threads in (%c1, %c1, %c1) args(%arg0_gpu : memref<1024x1024xi32>, %arg1_gpu : memref<1024x1024xi32>, %result : memref<1024x1024xi32>)

gpu.dealloc %arg0_gpu : memref<1024x1024xi32>
gpu.dealloc %arg1_gpu : memref<1024x1024xi32>
return %result : memref<1024x1024xi32>

}

// NOTE(review): known_block_size/known_grid_size below (1x32x1 / 1x1x1) do
// not match the actual launch dimensions in @eltwise_int_test
// (blocks 64x32x1, threads 1x1x1) — confirm whether these attributes are
// consumed by the pipeline or are stale copy-paste.
gpu.module @eltwise_int attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Bfloat16ConversionINTEL, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, VectorComputeINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_bfloat16_conversion, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
// Kernel: each workgroup processes one 16x32 tile at offset
// (block_id_x * 16, block_id_y * 32). With A=5, B=2 the per-op expected
// values are annotated inline (//=N).
gpu.func @eltwise_int(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>, %arg2: memref<1024x1024xi32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 32, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index

%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y

%m = arith.muli %block_id_x, %c16 : index
%n = arith.muli %block_id_y, %c32 : index

%1 = xetile.init_tile %arg0[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32>
%2 = xetile.load_tile %1: !xetile.tile<16x32xi32> -> vector<16x32xi32>
%3 = xetile.init_tile %arg1[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32>
%4 = xetile.load_tile %3: !xetile.tile<16x32xi32> -> vector<16x32xi32>
%result_add = arith.addi %2, %4: vector<16x32xi32> //=7
%result_sub = arith.subi %2, %4: vector<16x32xi32> //=3
%result_mul = arith.muli %result_add, %result_sub: vector<16x32xi32> //=21
%result_sdiv = arith.divsi %result_mul, %result_add: vector<16x32xi32> //=3
%result_udiv = arith.divui %result_mul, %result_add: vector<16x32xi32> //=3
%result_srem = arith.remsi %result_sdiv, %result_mul: vector<16x32xi32> //=3
%result_urem = arith.remui %result_udiv, %result_srem: vector<16x32xi32> //=0
%result = arith.addi %result_srem, %result_urem: vector<16x32xi32> //=3
%store_tile = xetile.init_tile %arg2[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<16x32xi32>
xetile.store_tile %result, %store_tile: vector<16x32xi32>, !xetile.tile<16x32xi32>
gpu.return
}
}

// Entry point: runs the test on the two constant globals and prints the
// result memref; FileCheck verifies all 1024*1024 = 1048576 elements are 3.
func.func @main() attributes {llvm.emit_c_interface} {
%A = memref.get_global @__constant_5_1024x1024xi32 : memref<1024x1024xi32>
%B = memref.get_global @__constant_2_1024x1024xi32 : memref<1024x1024xi32>

%c0_i32 = arith.constant 0 : i32

%result = call @eltwise_int_test(%A, %B) : (memref<1024x1024xi32>, memref<1024x1024xi32>) -> memref<1024x1024xi32>
%result_cast = memref.cast %result : memref<1024x1024xi32> to memref<*xi32>
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
// CHECK-COUNT-1048576: 3
call @printMemrefI32(%result_cast) : (memref<*xi32>) -> ()

return
}
func.func private @printMemrefI32(memref<*xi32>) attributes {llvm.emit_c_interface}
}

0 comments on commit c741d7d

Please sign in to comment.