Skip to content

Commit 47f4f39

Browse files
authored
[MLIR][AMDGPU] Fixing word alignment check for bufferload fastpath (#135982)
The `delta_bytes % (32 ceilDiv elementBitwidth) != 0` condition introduced in #135014 is incorrect. For example, suppose the last load is issued to load only the single last fp16 element. Then `delta_bytes = 2` and `(32 ceilDiv 16) = 2`, so the access is judged to be word aligned. It is sent to the fast path but gets all zeros for the fp16 because it crosses the word boundary. In reality the check should simply be `delta_bytes % 4`, since a word is 4 bytes. This PR fixes the bug by amending the mod target to 4 bytes (expressed in the diff in element units, as `delta % elements_per_word`).
1 parent 5a99355 commit 47f4f39

File tree

3 files changed

+10
-17
lines changed

3 files changed

+10
-17
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

+2-7
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1212
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
14-
#include "mlir/Dialect/Arith/Utils/Utils.h"
1514
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1615
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1716
#include "mlir/Dialect/SCF/IR/SCF.h"
1817
#include "mlir/Dialect/Vector/IR/VectorOps.h"
19-
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2018
#include "mlir/IR/BuiltinTypes.h"
2119
#include "mlir/IR/OpDefinition.h"
2220
#include "mlir/IR/PatternMatch.h"
@@ -225,15 +223,12 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
225223
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
226224
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
227225

228-
// 2) check if (detla_bytes % (32 / elementBitwidth) != 0)
229-
Value deltaBytes = rewriter.create<arith::MulIOp>(
230-
loc, delta,
231-
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
226+
// 2) check if (delta % elements_per_word != 0)
232227
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
233228
loc, llvm::divideCeil(32, elementBitWidth));
234229
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
235230
loc, arith::CmpIPredicate::ne,
236-
rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
231+
rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
237232
rewriter.create<arith::ConstantIndexOp>(loc, 0));
238233

239234
// We take the fallback of transfer_read default lowering only if it is both

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

+6-10
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
1010
return %res : vector<4xf32>
1111
}
1212

13-
// CHECK: %[[FALSE:.*]] = arith.constant false
14-
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
13+
// CHECK: %[[IF:.*]] = scf.if
1514
// CHECK: vector.transfer_read %[[ARG0]][%[[ARG1]], %[[ARG1]]]
1615

1716
// CHECK: } else {
@@ -35,14 +34,13 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
3534
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
3635
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 64
3736
// CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
38-
// CHECK-DAG: %[[VECTORSIZE:.*]] = arith.constant 4
37+
// CHECK-DAG: %[[C4:.*]] = arith.constant 4
3938

4039
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
4140
// CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
42-
// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[VECTORSIZE]]
41+
// CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[C4]]
4342

44-
// CHECK: %[[DELTABYTES:.*]] = arith.muli %[[DELTA]], %[[BYTES]]
45-
// CHECK: %[[REM:.*]] = arith.remui %[[DELTABYTES]], %[[BYTES]]
43+
// CHECK: %[[REM:.*]] = arith.remui %[[DELTA]], %[[BYTES]]
4644
// CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
4745

4846
// CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
@@ -120,8 +118,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
120118
return %res : vector<4xf32>
121119
}
122120
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
123-
// CHECK: %[[FALSE:.*]] = arith.constant false
124-
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<4xf32>) {
121+
// CHECK: %[[IF:.*]] = scf.if
125122
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
126123
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]
127124
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
@@ -140,7 +137,6 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
140137
return %res : vector<1xf32>
141138
}
142139
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
143-
// CHECK: %[[FALSE:.*]] = arith.constant false
144-
// CHECK: %[[IF:.*]] = scf.if %[[FALSE]] -> (vector<1xf32>) {
140+
// CHECK: %[[IF:.*]] = scf.if
145141
// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
146142
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[CST]]

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

+2
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,7 @@ cc_library(
15591559
hdrs = glob(["include/mlir/Dialect/AMDGPU/Transforms/*.h"]),
15601560
includes = ["include"],
15611561
deps = [
1562+
":AffineDialect",
15621563
":AMDGPUDialect",
15631564
":AMDGPUPassIncGen",
15641565
":AMDGPUUtils",
@@ -1569,6 +1570,7 @@ cc_library(
15691570
":FuncDialect",
15701571
":GPUDialect",
15711572
":IR",
1573+
":LLVMSupportHeaders",
15721574
":MemRefDialect",
15731575
":MemRefUtils",
15741576
":Pass",

0 commit comments

Comments
 (0)