Skip to content

Commit

Permalink
Add support for xetile.atomic_rmw op in init-duplicate pass (#975)
Browse files Browse the repository at this point in the history
* Add support for xetile.atomic_rmw op in init-duplicate pass

* Fix pre-commit
  • Loading branch information
nbpatel authored Dec 3, 2024
1 parent b5483b9 commit e93172a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
56 changes: 55 additions & 1 deletion include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class TileUsageAnalysis {
Usage[op] |= (uint)UsageType::PREFETCH;
} else if (llvm::isa<imex::xetile::StoreTileOp>(user)) {
Usage[op] |= (uint)UsageType::STORE;
} else if (llvm::isa<imex::xetile::AtomicRMWOp>(user)) {
Usage[op] |= (uint)UsageType::ATOMICRMW;
} else if (llvm::isa<imex::xetile::UpdateTileOffsetOp>(user)) {
Usage[op] |= (uint)UsageType::OTHER;
} else if (auto forOp =
Expand Down Expand Up @@ -162,6 +164,17 @@ class TileUsageAnalysis {
return false;
}

/// Returns true when the usage bits recorded for `op` include ATOMICRMW and
/// none of LOAD, STORE or PREFETCH. Returns false when `op` has no entry in
/// the Usage map.
bool isForAtomicRMW(imex::xetile::InitTileOp op) {
  if (!Usage.count(op))
    return false;

  // Read the recorded bit mask once and test it against the relevant flags.
  const auto flags = Usage[op];
  const bool hasAtomicRMW = flags & UsageType::ATOMICRMW;
  const bool hasExcluded =
      flags & (UsageType::LOAD | UsageType::STORE | UsageType::PREFETCH);
  return hasAtomicRMW && !hasExcluded;
}

//
bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) {
if (Usage.count(op)) {
Expand Down Expand Up @@ -193,6 +206,28 @@ class TileUsageAnalysis {
return false;
}

/// Returns true when the usage bits recorded for `op` include both LOAD and
/// ATOMICRMW, and neither STORE nor PREFETCH. Returns false when `op` has no
/// entry in the Usage map.
bool isForLoadAndAtomicRMW(imex::xetile::InitTileOp op) {
  if (!Usage.count(op))
    return false;

  // Read the recorded bit mask once and test it against the relevant flags.
  const auto flags = Usage[op];
  const bool hasRequired =
      (flags & UsageType::LOAD) && (flags & UsageType::ATOMICRMW);
  const bool hasExcluded = flags & (UsageType::STORE | UsageType::PREFETCH);
  return hasRequired && !hasExcluded;
}

/// Returns true when the usage bits recorded for `op` include both STORE and
/// ATOMICRMW, and neither LOAD nor PREFETCH. Returns false when `op` has no
/// entry in the Usage map.
bool isForAtomicRMWAndStore(imex::xetile::InitTileOp op) {
  if (!Usage.count(op))
    return false;

  // Read the recorded bit mask once and test it against the relevant flags.
  const auto flags = Usage[op];
  const bool hasRequired =
      (flags & UsageType::STORE) && (flags & UsageType::ATOMICRMW);
  const bool hasExcluded = flags & (UsageType::LOAD | UsageType::PREFETCH);
  return hasRequired && !hasExcluded;
}

private:
enum UsageType {
None = 0,
Expand All @@ -202,7 +237,8 @@ class TileUsageAnalysis {
DPAS_A = 8,
DPAS_B = 16,
DPAS_C = 32,
OTHER = 64
ATOMICRMW = 64,
OTHER = 128
};

llvm::DenseMap<mlir::Operation *, uint> Usage;
Expand Down Expand Up @@ -526,6 +562,12 @@ class XeConversionPattern : public mlir::RewritePattern {
return llvm::cast<TileUsageAnalysis>(analysis).isForPrefetch(op);
}

/// Casts `analysis` to TileUsageAnalysis and forwards the query to its
/// isForAtomicRMW(op).
/// NOTE(review): the default template argument names std::enable_if<...>
/// itself rather than its nested ::type, so it does not appear to restrict
/// AnalysisT — confirm whether that is intended.
template <typename = typename std::enable_if<
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForAtomicRMW(imex::xetile::InitTileOp op) const {
return llvm::cast<TileUsageAnalysis>(analysis).isForAtomicRMW(op);
}

template <typename = typename std::enable_if<
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) const {
Expand All @@ -537,6 +579,18 @@ class XeConversionPattern : public mlir::RewritePattern {
bool isForLoadAndStore(imex::xetile::InitTileOp op) const {
return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndStore(op);
}

/// Casts `analysis` to TileUsageAnalysis and forwards the query to its
/// isForLoadAndAtomicRMW(op).
/// NOTE(review): the default template argument names std::enable_if<...>
/// itself rather than its nested ::type, so it does not appear to restrict
/// AnalysisT — confirm whether that is intended.
template <typename = typename std::enable_if<
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForLoadAndAtomicRMW(imex::xetile::InitTileOp op) const {
return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndAtomicRMW(op);
}

/// Casts `analysis` to TileUsageAnalysis and forwards the query to its
/// isForAtomicRMWAndStore(op).
/// NOTE(review): the default template argument names std::enable_if<...>
/// itself rather than its nested ::type, so it does not appear to restrict
/// AnalysisT — confirm whether that is intended.
template <typename = typename std::enable_if<
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForAtomicRMWAndStore(imex::xetile::InitTileOp op) const {
return llvm::cast<TileUsageAnalysis>(analysis).isForAtomicRMWAndStore(op);
}
};

/// Clone `shape` with the last two elements swapped.
Expand Down
7 changes: 5 additions & 2 deletions lib/Dialect/XeTile/Transforms/InitDuplicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,14 @@ class XeTileInitDuplicatePass
op->walk([&](imex::xetile::InitTileOp op) {
mlir::OpBuilder rewriter(op);
if (usageAnalysis.isForLoadAndStore(op) ||
usageAnalysis.isForLoadAndPrefetch(op)) {
usageAnalysis.isForLoadAndPrefetch(op) ||
usageAnalysis.isForLoadAndAtomicRMW(op) ||
usageAnalysis.isForAtomicRMWAndStore(op)) {
mlir::Operation *cloneOp = rewriter.clone(*op);
for (auto user : op->getUsers()) {
if (llvm::isa<xetile::StoreTileOp>(user) ||
llvm::dyn_cast<xetile::PrefetchTileOp>(user)) {
llvm::dyn_cast<xetile::PrefetchTileOp>(user) ||
llvm::dyn_cast<xetile::AtomicRMWOp>(user)) {
auto *targetOp = llvm::dyn_cast_if_present<mlir::Operation *>(user);
targetOp->replaceUsesOfWith(op->getResults()[0],
cloneOp->getResults()[0]);
Expand Down
31 changes: 31 additions & 0 deletions test/Dialect/XeTile/Transforms/init_duplicate.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate %s -verify-diagnostics -o -| FileCheck %s

// Tests for the init-duplicate pass on tiles that are used by
// xetile.atomic_rmw together with another tile consumer.
gpu.module @test_kernel {
// Case 1: one init_tile result feeds both xetile.atomic_rmw and
// xetile.store_tile. The expected output has two identical init_tile ops;
// atomic_rmw consumes the second one and store_tile the first one.
//CHECK: gpu.func @init_duplicate(%[[value:.*]]: vector<32x64xf32>, %[[arg0:.*]]: memref<256x256xf32>)
gpu.func @init_duplicate(%value: vector<32x64xf32>, %arg0: memref<256x256xf32>) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[INITTILE_0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
// CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
// CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_1]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32>
// CHECK: xetile.store_tile %[[ATOMICRMW]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32>
%c0 = arith.constant 0 : index
%tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
%rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32>
xetile.store_tile %rmw, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32>
gpu.return
}

// Case 2: one init_tile result feeds both xetile.load_tile and
// xetile.atomic_rmw. The expected output has two identical init_tile ops;
// load_tile consumes the second one and atomic_rmw the first one.
//CHECK: gpu.func @init_duplicate_1(%[[value:.*]]: vector<32x64xf32>, %[[arg0:.*]]: memref<256x256xf32>)
gpu.func @init_duplicate_1(%value: vector<32x64xf32>, %arg0: memref<256x256xf32>) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[INITTILE_0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
// CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
// CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILE_1]] : !xetile.tile<32x64xf32> -> vector<32x64xf32>
// CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32>
%c0 = arith.constant 0 : index
%tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32>
%load = xetile.load_tile %tile : !xetile.tile<32x64xf32> -> vector<32x64xf32>
%rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32>
gpu.return
}
}

0 comments on commit e93172a

Please sign in to comment.