@@ -45,14 +45,19 @@ bool hasAssignedMemSpace(Value value) {
   return false;
 }
 
+// Converts `memref::AllocOp` within GPU regions to the GPU shared local
+// memory. Adjusts the allocation shape based on GPU block dimensions and
+// creates a `memref::SubViewOp` for thread-specific memory access.
 struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
   using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
 
   ConvertAlloc(MLIRContext *ctx) : OpRewritePattern<memref::AllocOp>(ctx) {}
 
   LogicalResult matchAndRewrite(memref::AllocOp allocOp,
                                 PatternRewriter &rewriter) const override {
-    if (hasAssignedMemSpace(allocOp->getResult(0))) {
+    Value memref = allocOp->getResult(0);
+
+    if (hasAssignedMemSpace(memref)) {
       return rewriter.notifyMatchFailure(
           allocOp, "Memref already has some memory space attribute");
     }
@@ -62,22 +67,90 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
           "Only support allocs in GPU regions");
     }
 
-    Value memref = allocOp->getResult(0);
+    auto launchOp = allocOp->getParentOfType<gpu::LaunchOp>();
+
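+    // The shared buffer is sized from the work-group dimensions, so the
+    // block sizes must be compile-time constants.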
+    auto xSz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeX().getDefiningOp());
+    auto ySz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeY().getDefiningOp());
+    auto zSz = dyn_cast<arith::ConstantIndexOp>(
+        launchOp.getBlockSizeZ().getDefiningOp());
+
+    if (!xSz || !ySz || !zSz)
+      return rewriter.notifyMatchFailure(
+          allocOp, "Only support constant block sizes for now");
+
+    int64_t xI = xSz.value();
+    int64_t yI = ySz.value();
+    int64_t zI = zSz.value();
+
+    if (zI != 1)
+      return rewriter.notifyMatchFailure(
+          allocOp, "Only support 2D shared memory for now");
+
     MemRefType originalMemRefType = cast<MemRefType>(memref.getType());
+    auto originalShape = originalMemRefType.getShape();
+
+    // Scale the allocation size by the number of threads in the work-group
+    int64_t newX = originalShape[0] * xI;
+    int64_t newY = originalShape[1] * yI;
+
+    SmallVector<int64_t> newShape = {newX, newY};
 
     IntegerAttr sharedAddressSpace =
         IntegerAttr::get(rewriter.getIntegerType(64),
                          static_cast<int64_t>(gpu::AddressSpace::Private));
 
-    // Create a new MemRefType with the desired address space
-    MemRefType newMemRefType = MemRefType::get(
-        originalMemRefType.getShape(), originalMemRefType.getElementType(),
-        originalMemRefType.getLayout(), sharedAddressSpace);
-
-    Value newMemRef = rewriter.create<memref::AllocOp>(
-        allocOp.getLoc(), newMemRefType, allocOp.getOperands());
-
-    memref.replaceAllUsesWith(newMemRef);
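+    // Create a new MemRefType with the scaled shape and the desired address
+    // space.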
+    MemRefType newRootMemRefType =
+        MemRefType::get(newShape, originalMemRefType.getElementType(),
+                        originalMemRefType.getLayout(), sharedAddressSpace);
+
+    Value newRootMemRef =
+        rewriter
+            .create<memref::AllocOp>(allocOp.getLoc(), newRootMemRefType,
+                                     allocOp.getOperands())
+            .getResult();
+
+    // Compute the offsets in the SLM chunk for the current thread
+    auto origXConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
+                                                              originalShape[0]);
+    auto origYConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
+                                                              originalShape[1]);
+
+    auto threadIds = launchOp.getThreadIds();
+
+    auto offX =
+        rewriter
+            .create<arith::MulIOp>(allocOp.getLoc(), threadIds.x, origXConst)
+            .getResult();
+    auto offY =
+        rewriter
+            .create<arith::MulIOp>(allocOp.getLoc(), threadIds.y, origYConst)
+            .getResult();
+
+    auto offsets = getMixedValues({ShapedType::kDynamic, ShapedType::kDynamic},
+                                  {offX, offY}, rewriter);
+    auto sizes = getMixedValues(originalShape, {}, rewriter);
+    auto strides = getMixedValues({1, 1}, {}, rewriter);
+
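+    // The subview gives each thread its own tile of the original shape
+    // within the enlarged shared buffer.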
+    auto newSlice =
+        rewriter
+            .create<memref::SubViewOp>(allocOp.getLoc(), newRootMemRef, offsets,
+                                       sizes, strides)
+            .getResult();
+    memref.replaceAllUsesWith(newSlice);
+
+    // Erase deallocs since we don't need them for SLM. The early-inc range
+    // lets us erase a user without invalidating the iterator.
+    for (auto user : llvm::make_early_inc_range(newSlice.getUsers()))
+      if (auto deallocOp = dyn_cast<memref::DeallocOp>(user))
+        rewriter.eraseOp(deallocOp);
 
     return success();
   }