Commit d9921c4

[AllocsToSLM] Add thread-specific offsets (#407)
Signed-off-by: dchigarev <[email protected]>
1 parent 1377cc3 commit d9921c4

2 files changed (+104 -25)

lib/gc/Transforms/GPU/AllocsToSLM.cpp (+77 -11)
@@ -45,14 +45,19 @@ bool hasAssignedMemSpace(Value value) {
   return false;
 }
 
+// Converts `memref::AllocOp` within GPU regions to the GPU shared local
+// memory. Adjusts the allocation shape based on GPU block dimensions and
+// creates a `memref::SubViewOp` for thread-specific memory access.
 struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
   using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
 
   ConvertAlloc(MLIRContext *ctx) : OpRewritePattern<memref::AllocOp>(ctx) {}
 
   LogicalResult matchAndRewrite(memref::AllocOp allocOp,
                                 PatternRewriter &rewriter) const override {
-    if (hasAssignedMemSpace(allocOp->getResult(0))) {
+    Value memref = allocOp->getResult(0);
+
+    if (hasAssignedMemSpace(memref)) {
       return rewriter.notifyMatchFailure(
           allocOp, "Memref already has some memory space attribute");
     }
@@ -62,22 +67,83 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
           "Only support allocs in GPU regions");
     }
 
-    Value memref = allocOp->getResult(0);
+    auto launchOp = allocOp->getParentOfType<gpu::LaunchOp>();
+
+    auto xSz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeX().getDefiningOp());
+    auto ySz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeY().getDefiningOp());
+    auto zSz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeZ().getDefiningOp());
+
+    if (!xSz || !ySz || !zSz)
+      return rewriter.notifyMatchFailure(
+          allocOp, "Only support constant block sizes for now");
+
+    int64_t xI = xSz.value();
+    int64_t yI = ySz.value();
+    int64_t zI = zSz.value();
+
+    if (zI != 1)
+      return rewriter.notifyMatchFailure(
+          allocOp, "Only support 2D shared memory for now");
+
     MemRefType originalMemRefType = cast<MemRefType>(memref.getType());
+    auto originalShape = originalMemRefType.getShape();
+
+    // Scale the allocation size by the number of threads in the work-group
+    int64_t newX = originalShape[0] * xI;
+    int64_t newY = originalShape[1] * yI;
+
+    SmallVector<int64_t> newShape = {newX, newY};
 
     IntegerAttr sharedAddressSpace =
         IntegerAttr::get(rewriter.getIntegerType(64),
                          static_cast<int64_t>(gpu::AddressSpace::Private));
 
-    // Create a new MemRefType with the desired address space
-    MemRefType newMemRefType = MemRefType::get(
-        originalMemRefType.getShape(), originalMemRefType.getElementType(),
-        originalMemRefType.getLayout(), sharedAddressSpace);
-
-    Value newMemRef = rewriter.create<memref::AllocOp>(
-        allocOp.getLoc(), newMemRefType, allocOp.getOperands());
-
-    memref.replaceAllUsesWith(newMemRef);
+    MemRefType newRootMemRefType =
+        MemRefType::get(newShape, originalMemRefType.getElementType(),
+                        originalMemRefType.getLayout(), sharedAddressSpace);
+
+    Value newRootMemRef =
+        rewriter
+            .create<memref::AllocOp>(allocOp.getLoc(), newRootMemRefType,
+                                     allocOp.getOperands())
+            .getResult();
+
+    // Compute the offsets in SLM chunk for the current thread
+    auto origXConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
+                                                              originalShape[0]);
+    auto origYConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
+                                                              originalShape[1]);
+
+    auto threadIds = launchOp.getThreadIds();
+
+    auto offX =
+        rewriter
+            .create<arith::MulIOp>(allocOp.getLoc(), threadIds.x, origXConst)
+            .getResult();
+    auto offY =
+        rewriter
+            .create<arith::MulIOp>(allocOp.getLoc(), threadIds.y, origYConst)
+            .getResult();
+
+    auto offsets = getMixedValues({ShapedType::kDynamic, ShapedType::kDynamic},
+                                  {offX, offY}, rewriter);
+    auto sizes = getMixedValues(originalShape, {}, rewriter);
+    auto strides = getMixedValues({1, 1}, {}, rewriter);
+
+    auto newSlice =
+        rewriter
+            .create<memref::SubViewOp>(allocOp.getLoc(), newRootMemRef, offsets,
+                                       sizes, strides)
+            .getResult();
+    memref.replaceAllUsesWith(newSlice);
+
+    // Erase deallocs since we don't need them for SLM
+    for (auto user : newSlice.getUsers())
+      if (auto deallocOp = dyn_cast<memref::DeallocOp>(user))
+        deallocOp->erase();
 
     return success();
   }
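
For illustration, here is a minimal before/after sketch of the IR this pattern now produces for a per-thread 16x32 buffer inside a gpu.launch with a 2x4x1 thread block (shapes match the test below; the SSA names are illustrative, not taken from actual pass output):

  // Before: each thread allocates its own 16x32 buffer.
  %buf = memref.alloc() : memref<16x32xf16>

  // After: a single work-group-wide SLM allocation (address space 3), scaled by
  // the thread-block size, plus a thread-specific subview at (16 * %tx, 32 * %ty).
  %slm   = memref.alloc() : memref<32x128xf16, 3>
  %off_x = arith.muli %tx, %c16 : index
  %off_y = arith.muli %ty, %c32 : index
  %view  = memref.subview %slm[%off_x, %off_y] [16, 32] [1, 1]
           : memref<32x128xf16, 3> to memref<16x32xf16, strided<[128, 1], offset: ?>, 3>
  // All uses of %buf are redirected to %view, and its memref.dealloc is erased.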

test/mlir/test/gc/Transforms/GPU/allocs-to-slm.mlir (+27 -14)
@@ -2,25 +2,38 @@
 
 func.func @entry() {
   %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
 
   // Memory space wasn't assigned as it's allocated outside of gpu.launch block
-  // CHECK: %[[NEW_MEMREF_0:.*]] = memref.alloc() : memref<16x16xf16>
-  %0 = memref.alloc() : memref<16x16xf16>
-  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c1, %sz_by = %c1, %sz_bz = %c1)
-             threads(%tx, %ty, %tz) in (%sz_tx = %c1, %sz_ty = %c1, %sz_tz = %c1) {
+  // CHECK: %[[NEW_MEMREF_0:.*]] = memref.alloc() : memref<16x32xf16>
+  %0 = memref.alloc() : memref<16x32xf16>
+  // Capture thread-id variables
+  // CHECK: gpu.launch blocks(%[[ARG0:.+]], %[[ARG1:.+]], %[[ARG2:.+]]) in (%[[ARG6:.+]] = %c2, %[[ARG7:.+]] = %c2, %[[ARG8:.+]] = %c1) threads
+  // CHECK-SAME: (%[[THREAD_X:.+]], %[[THREAD_Y:.+]], %[[ARG5:.+]]) in
+  // CHECK-SAME: (%[[ARG9:.+]] = %c2, %[[ARG10:.+]] = %c4, %[[ARG11:.+]] = %c1) {
+  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c2, %sz_by = %c2, %sz_bz = %c1)
+             threads(%tx, %ty, %tz) in (%sz_tx = %c2, %sz_ty = %c4, %sz_tz = %c1) {
     // Memory space was changed as it's explicitly specifided
-    // CHECK: %[[NEW_MEMREF_1:.*]] = memref.alloc() : memref<16x16xf16, 1>
-    %1 = memref.alloc() : memref<16x16xf16, 1>
+    // CHECK: %[[NEW_MEMREF_1:.*]] = memref.alloc() : memref<16x32xf16, 1>
+    %1 = memref.alloc() : memref<16x32xf16, 1>
     // Added 'shared' memory space
-    // CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<16x16xf16, 3>
-    %2 = memref.alloc() : memref<16x16xf16>
+    // CHECK: %[[NEW_MEMREF_2:.*]] = memref.alloc() : memref<32x128xf16, 3>
+    // CHECK: %[[OFF_X:.*]] = arith.muli %[[THREAD_X]], %c16 : index
+    // CHECK: %[[OFF_Y:.*]] = arith.muli %[[THREAD_Y]], %c32 : index
+    // CHECK: %[[NEW_MEMREF_3:.*]] = memref.subview %[[NEW_MEMREF_2]][%[[OFF_X]], %[[OFF_Y]]] [16, 32] [1, 1]
+    // CHECK-SAME: memref<32x128xf16, 3> to memref<16x32xf16, strided<[128, 1], offset: ?>, 3>
+    %2 = memref.alloc() : memref<16x32xf16>
 
-    // CHECK: linalg.add ins(%[[NEW_MEMREF_1]], %[[NEW_MEMREF_2]] : memref<16x16xf16, 1>, memref<16x16xf16, 3>) outs(%[[NEW_MEMREF_0]] : memref<16x16xf16>)
-    linalg.add ins(%1, %2 :memref<16x16xf16, 1>, memref<16x16xf16>) outs(%0 : memref<16x16xf16>)
-    // CHECK: memref.dealloc %[[NEW_MEMREF_1]] : memref<16x16xf16, 1>
-    // CHECK: memref.dealloc %[[NEW_MEMREF_2]] : memref<16x16xf16, 3>
-    memref.dealloc %1 : memref<16x16xf16, 1>
-    memref.dealloc %2 : memref<16x16xf16>
+    // CHECK: linalg.add ins(%[[NEW_MEMREF_1]], %[[NEW_MEMREF_3]] :
+    // CHECK-SAME: memref<16x32xf16, 1>, memref<16x32xf16, strided<[128, 1], offset: ?>, 3>) outs(%[[NEW_MEMREF_0]] : memref<16x32xf16>)
+    linalg.add ins(%1, %2 :memref<16x32xf16, 1>, memref<16x32xf16>) outs(%0 : memref<16x32xf16>)
+    // CHECK: memref.dealloc %[[NEW_MEMREF_1]] : memref<16x32xf16, 1>
+    // Verify that there are no deallocs for SLM
+    // CHECK-NOT: memref.dealloc %[[NEW_MEMREF_2]] .*
+    // CHECK-NOT: memref.dealloc %[[NEW_MEMREF_3]] .*
+    memref.dealloc %1 : memref<16x32xf16, 1>
+    memref.dealloc %2 : memref<16x32xf16>
     gpu.terminator
   }
   return
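
The shapes above follow directly from the scaling rule in the pattern: with a (2, 4, 1) thread block and a per-thread shape of 16x32, the work-group-wide SLM buffer is (16 * 2) x (32 * 4) = 32x128, and the thread with ids (%tx, %ty) works on the 16x32 slice starting at offset (16 * %tx, 32 * %ty). The CHECK-NOT lines then verify that no dealloc is emitted for the SLM buffer or its subview.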
