Skip to content

Commit 28a4165

Browse files
committed
[MHAL] Update MHAL to use gpu target attrs: SerializeToHSA to TargetAttr (2/3)
This patch updates the MHAL project to use the `gpu` target attributes infrastructure. It is part of the series switching the compilation infrastructure from SerializeToHSA to `gpu` target attributes. This includes: - Updating `PrefillPass` to work on GPU binaries. - Updating `PackageTargetsPass` to work on GPU binaries. - Updating `MHALToGPU` to create `gpu.binary` operations. - Updating `mhal::TargetObject` to store an attribute instead of a string. - Adding the `DropMetadata` pass, which drops all metadata from GPU binaries (e.g. the property dictionary and the kernel metadata). This is required to avoid unregistered-dialect errors for attributes that are stored in the metadata, such as `mhal.prefill`, but cannot be parsed by tools like `mlir-cpu-runner`.
1 parent 610cb3f commit 28a4165

File tree

10 files changed

+93
-69
lines changed

10 files changed

+93
-69
lines changed

Diff for: external/mlir-hal/include/mlir/Dialect/MHAL/IR/MHALAttrDefs.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def MHAL_TargetObjectAttr : MHAL_Attr<"TargetObject"> {
4040
AttrParameter<"::mlir::mhal::TargetObjectType", "The target object type">:$type,
4141
StringRefParameter<"The architecture target">:$arch,
4242
AttrParameter<"DictionaryAttr", "The object type">:$attributes,
43-
StringRefParameter<"The object binary">:$binary
43+
AttrParameter<"Attribute", "The object binary">:$binary
4444
);
4545

4646
let genVerifyDecl = 0;

Diff for: external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace mhal {
3333
#define GEN_PASS_DECL_MHALSELECTTARGETSPASS
3434
#define GEN_PASS_DECL_MHALBUFFERIZEPASS
3535
#define GEN_PASS_DECL_MHALPREFILLPASS
36+
#define GEN_PASS_DECL_MHALDROPBINARYMETADATAPASS
3637

3738
#define GEN_PASS_REGISTRATION
3839
#include "mlir/Dialect/MHAL/Transforms/Passes.h.inc"

Diff for: external/mlir-hal/include/mlir/Dialect/MHAL/Transforms/Passes.td

+4
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,8 @@ def MHALPrefillPass : Pass<"mhal-prefill", "func::FuncOp"> {
4646
let dependentDialects = ["mhal::MHALDialect"];
4747
}
4848

49+
def MHALDropBinaryMetadataPass : Pass<"mhal-drop-binary-metadata", "ModuleOp"> {
50+
let summary = "drops all metadata stored in GPU binaries";
51+
}
52+
4953
#endif // MLIR_DIALECT_MHAL_PASSES

Diff for: external/mlir-hal/lib/Conversion/MHALToGPU/MHALToGPU.cpp

+12-36
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ struct LaunchRewritePattern : public OpRewritePattern<mhal::LaunchOp> {
172172
if (!kernelPkg.has_value())
173173
return rw.notifyMatchFailure(op, "no gpu target");
174174

175-
auto arch = kernelPkg->getTarget();
176175
auto targetObj = kernelPkg->getObject();
177176
auto binary = targetObj.getBinary();
178177
auto launchDims = kernelPkg->getLaunchDims();
@@ -184,46 +183,20 @@ struct LaunchRewritePattern : public OpRewritePattern<mhal::LaunchOp> {
184183
auto func = *getCalledFunc(op);
185184
Location floc = func.getLoc();
186185

187-
// 2. create dummy gpu.module for reference from gpu.launch_func
188-
// - with gpu.binary, arch attributes
189-
// - and gpu.func (referenced by gpu.launch_func
190-
// gpu.module @<func_name>_module attributes {arch = "gfx908", gpu.binary
191-
// = "\7FELF\..."} {
192-
// gpu.func @<func_name> (...) attributes {block_size = 256 : i32,
193-
// grid_size = 900 : i32, gpu.kernel}
186+
// 2. re-materialize gpu.binary @<func_name>_module [#gpu.object<...>]
194187

195188
FunctionOpInterface funcIF(func);
196189
auto funcName = funcIF.getName();
197-
auto gpuModuleName = funcName + "_module";
190+
auto binaryName = funcName + "_module";
198191

199-
auto gpuModule = module.lookupSymbol<gpu::GPUModuleOp>(gpuModuleName.str());
200-
if (!gpuModule) {
192+
auto binaryOp = module.lookupSymbol<gpu::BinaryOp>(binaryName.str());
193+
if (!binaryOp) {
201194
OpBuilder b(ctx);
202-
gpuModule = b.create<gpu::GPUModuleOp>(floc, gpuModuleName.str());
203-
gpuModule->setAttr("arch", b.getStringAttr(arch));
204-
gpuModule->setAttr("gpu.binary", b.getStringAttr(binary));
195+
binaryOp = b.create<gpu::BinaryOp>(floc, binaryName.str(), nullptr,
196+
ArrayRef<Attribute>({binary}));
205197

206198
SymbolTable symbolTable(module);
207-
symbolTable.insert(gpuModule);
208-
}
209-
210-
auto gpuFunc = gpuModule.lookupSymbol<gpu::GPUFuncOp>(funcName);
211-
if (!gpuFunc) {
212-
OpBuilder b(gpuModule.getContext());
213-
gpuFunc =
214-
b.create<gpu::GPUFuncOp>(floc, funcName, func.getFunctionType());
215-
gpuFunc->setAttr("block_size", b.getI32IntegerAttr(blockSize));
216-
gpuFunc->setAttr("grid_size", b.getI32IntegerAttr(gridSize));
217-
gpuFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
218-
b.getUnitAttr());
219-
220-
SymbolTable symbolTable(gpuModule);
221-
symbolTable.insert(gpuFunc);
222-
223-
// Must have a return
224-
auto block = &gpuFunc.front();
225-
b.setInsertionPoint(block, block->begin());
226-
b.create<gpu::ReturnOp>(floc, ValueRange{});
199+
symbolTable.insert(binaryOp);
227200
}
228201

229202
// 3. create substitute gpu.launch_func
@@ -281,9 +254,12 @@ struct LaunchRewritePattern : public OpRewritePattern<mhal::LaunchOp> {
281254

282255
// Make gpu.launch_func
283256
auto gpuLaunchOp = rw.create<gpu::LaunchFuncOp>(
284-
loc, gpuFunc, gpu::KernelDim3{gridSizeIdx, oneIdx, oneIdx},
257+
loc,
258+
SymbolRefAttr::get(getContext(), binaryName.str(),
259+
{FlatSymbolRefAttr::get(getContext(), funcName)}),
260+
gpu::KernelDim3{gridSizeIdx, oneIdx, oneIdx},
285261
gpu::KernelDim3{blockSizeIdx, oneIdx, oneIdx}, dynamicSharedMemorySize,
286-
gpuOperands, tokenType, asyncDeps);
262+
gpuOperands, tokenType, ValueRange(asyncDeps));
287263
Value token = gpuLaunchOp->getResult(0);
288264

289265
// Insert gpu.memcpy for results

Diff for: external/mlir-hal/lib/Dialect/MHAL/IR/MHAL.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ mlir::Attribute TargetObjectAttr::parse(mlir::AsmParser &parser,
9999
return {};
100100
}
101101

102-
std::string binary;
103-
if (parser.parseKeywordOrString(&binary)) {
102+
Attribute binary;
103+
if (parser.parseAttribute(binary)) {
104104
return {};
105105
}
106106

@@ -129,7 +129,7 @@ void TargetObjectAttr::print(mlir::AsmPrinter &printer) const {
129129

130130
// print binary
131131
printer << " -> ";
132-
printer.printKeywordOrString(getBinary());
132+
printer << getBinary();
133133
printer << ">";
134134
}
135135

Diff for: external/mlir-hal/lib/Dialect/MHAL/Pipelines/Pipelines.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ void mhal::buildRunnerPipeline(OpPassManager &pm,
150150
GpuToLLVMConversionPassOptions opts;
151151
opts.kernelBarePtrCallConv = options.barePtrMemrefs;
152152
pm.addPass(createGpuToLLVMConversionPass(opts));
153+
pm.addPass(createMHALDropBinaryMetadataPass());
153154

154155
pm.addPass(createConvertFuncToLLVMPass());
155156
pm.addPass(createReconcileUnrealizedCastsPass());

Diff for: external/mlir-hal/lib/Dialect/MHAL/Transforms/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRMHALTransforms
22
Bufferize.cpp
33
BufferizableOpInterfaceImpl.cpp
4+
DropMetadata.cpp
45
InferGraph.cpp
56
PackageTargets.cpp
67
SelectTargets.cpp
@@ -26,4 +27,3 @@ add_mlir_dialect_library(MLIRMHALTransforms
2627
MLIRSupport
2728
MLIRTransformUtils
2829
)
29-
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3+
#include "mlir/Dialect/MHAL/Transforms/Passes.h"
4+
#include "mlir/IR/Builders.h"
5+
#include "mlir/IR/BuiltinOps.h"
6+
7+
namespace mlir {
8+
namespace mhal {
9+
#define GEN_PASS_DEF_MHALDROPBINARYMETADATAPASS
10+
#include "mlir/Dialect/MHAL/Transforms/Passes.h.inc"
11+
} // namespace mhal
12+
} // namespace mlir
13+
14+
// Debug-type tag for this file; kept in sync with the pass argument
// `mhal-drop-binary-metadata` so it is not confused with the prefill pass.
#define DEBUG_TYPE "mhal-drop-binary-metadata"
15+
16+
using namespace mlir;
17+
18+
namespace {
19+
// Pass implementation for `mhal-drop-binary-metadata`. It derives from the
// generated MHALDropBinaryMetadataPassBase (pulled in above through
// GEN_PASS_DEF_MHALDROPBINARYMETADATAPASS) and only overrides runOnOperation.
class MHALDropBinaryMetadataPass
20+
: public mhal::impl::MHALDropBinaryMetadataPassBase<
21+
MHALDropBinaryMetadataPass> {
22+
public:
23+
// Inspect each gpu::BinaryOp and drop all the metadata.
// The definition lives out of line, below the anonymous namespace.
24+
void runOnOperation() override;
25+
};
26+
} // namespace
27+
28+
// Inspect each gpu::BinaryOp and drop all the metadata.
//
// Only gpu::BinaryOps that sit directly in the body of the operation this pass
// runs on are visited (getOps on the body block does not recurse into nested
// regions). For each one, the discardable attributes are cleared and the
// objects array is rebuilt from fresh gpu::ObjectAttr values that keep the
// target, format and object payload of the originals.
29+
void MHALDropBinaryMetadataPass::runOnOperation() {
30+
Builder b(&getContext());
31+
for (gpu::BinaryOp binary :
32+
getOperation().getBody()->getOps<gpu::BinaryOp>()) {
33+
// Drop all discardable attributes.
34+
binary->setDiscardableAttrs(b.getDictionaryAttr({}));
35+
SmallVector<Attribute, 10> objects;
36+
for (auto objRaw : binary.getObjects()) {
37+
auto object = cast<gpu::ObjectAttr>(objRaw);
38+
// Drop the property dictionary.
// NOTE(review): the two trailing nullptr arguments presumably stand for the
// properties dictionary and the kernel table -- confirm against the
// gpu::ObjectAttr builder signature.
39+
objects.push_back(
40+
b.getAttr<gpu::ObjectAttr>(object.getTarget(), object.getFormat(),
41+
object.getObject(), nullptr, nullptr));
42+
}
43+
// Replace the op's objects with the stripped copies.
binary.setObjectsAttr(b.getArrayAttr(objects));
44+
}
45+
}

Diff for: external/mlir-hal/lib/Dialect/MHAL/Transforms/PackageTargets.cpp

+15-21
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Dialect/Func/IR/FuncOps.h"
2323
#include "mlir/Dialect/GPU/Transforms/Passes.h"
2424
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
25+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
2526
#include "mlir/Dialect/MHAL/IR/MHAL.h"
2627
#include "mlir/Dialect/MHAL/Transforms/Passes.h"
2728
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,7 +42,6 @@ namespace mhal {
4142
using namespace mlir;
4243

4344
namespace {
44-
4545
struct MHALPackageTargetsPass
4646
: public mhal::impl::MHALPackageTargetsPassBase<MHALPackageTargetsPass> {
4747

@@ -55,36 +55,30 @@ struct MHALPackageTargetsPass
5555

5656
mod->walk([&](ModuleOp kernelMod) {
5757
if (kernelMod->hasAttr("mhal.module")) {
58-
SmallVector<gpu::GPUModuleOp, 8> gpuMods;
59-
kernelMod->walk([&](gpu::GPUModuleOp gpuMod) {
60-
auto binaryAttr = gpuMod->getAttrOfType<StringAttr>(
61-
gpu::getDefaultGpuBinaryAnnotation());
62-
if (!binaryAttr) {
63-
gpuMod.emitOpError() << "missing gpu.binary attribute";
64-
return;
65-
}
66-
67-
gpuMods.push_back(gpuMod);
68-
58+
SmallVector<gpu::BinaryOp, 8> binaries;
59+
kernelMod->walk([&](gpu::BinaryOp binary) {
60+
auto object = cast<gpu::ObjectAttr>(binary.getObjects()[0]);
61+
binaries.push_back(binary);
62+
gpu::KernelTableAttr metadata = object.getKernels();
63+
assert(metadata && "expected a valid metadata attribute");
6964
// apply target spec to original func
70-
gpuMod.walk([&](LLVM::LLVMFuncOp func) {
71-
if (auto attr =
72-
func->getAttrOfType<SymbolRefAttr>("original_func")) {
65+
for (auto [name, kernel] : metadata) {
66+
if (auto attr = kernel.getAttr<SymbolRefAttr>("original_func")) {
7367
if (auto kernelFunc = mod.lookupSymbol<func::FuncOp>(attr)) {
7468
auto archName =
7569
kernelMod->getAttrOfType<StringAttr>("mhal.arch")
7670
.getValue();
7771
auto funcName = attr.getLeafReference().getValue();
7872
uint32_t gridSize =
79-
func->getAttrOfType<IntegerAttr>("grid_size").getInt();
73+
kernel.getAttr<IntegerAttr>("grid_size").getInt();
8074
uint32_t blockSize =
81-
func->getAttrOfType<IntegerAttr>("block_size").getInt();
75+
kernel.getAttr<IntegerAttr>("block_size").getInt();
8276

8377
DictionaryAttr objAttrs;
8478

8579
auto xobj = mhal::TargetObjectAttr::get(
8680
b.getContext(), mhal::TargetObjectType::ELF, archName,
87-
objAttrs, binaryAttr);
81+
objAttrs, object);
8882

8983
DictionaryAttr pkgAttrs;
9084
// = b.getDictionaryAttr({
@@ -97,12 +91,12 @@ struct MHALPackageTargetsPass
9791
kernelImpls[kernelFunc].push_back(xpkg);
9892
}
9993
}
100-
});
94+
}
10195
});
10296

10397
// clean processed gpu.modules
104-
for (auto gpuMod : gpuMods) {
105-
gpuMod.erase();
98+
for (auto binary : binaries) {
99+
binary.erase();
106100
}
107101

108102
// remove __kernel_*

Diff for: external/mlir-hal/lib/Dialect/MHAL/Transforms/Prefill.cpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@ void MHALPrefillPass::insertPrefillOps(OpBuilder &builder,
4848
gpu::LaunchFuncOp &launchOp) {
4949
auto func = cast<func::FuncOp>(launchOp->getParentOp());
5050
auto module = cast<ModuleOp>(func->getParentOp());
51-
auto kernel = launchOp.getKernel();
52-
auto *callee = module.lookupSymbol(kernel);
53-
assert(callee != nullptr && "expect to find the function defenition");
54-
auto llvmFunc = cast<LLVM::LLVMFuncOp>(callee);
55-
auto gpuModule = cast<gpu::GPUModuleOp>(llvmFunc->getParentOp());
56-
51+
auto binaryName = launchOp.getKernelModuleName();
52+
auto binary = module.lookupSymbol<gpu::BinaryOp>(binaryName);
53+
assert(binary != nullptr && "expect to find the function defenition");
54+
auto objects = binary.getObjects().getValue();
55+
assert(objects.size() == 1 && "expected a single object");
5756
SmallVector<mhal::PrefillAttr, 4> prefillAttrs;
58-
if (auto moduleAttr = gpuModule->getAttr(llvmFunc.getSymName())) {
57+
auto object = cast<gpu::ObjectAttr>(objects[0]);
58+
DictionaryAttr objectProps = object.getProperties();
59+
if (!objectProps)
60+
return;
61+
if (auto moduleAttr = objectProps.get(launchOp.getKernelName())) {
5962
if (auto arrayAttr = dyn_cast<ArrayAttr>(moduleAttr)) {
6063
for (auto attr : arrayAttr) {
6164
if (auto prefillAttr = dyn_cast<mhal::PrefillAttr>(attr)) {

0 commit comments

Comments
 (0)