Skip to content

Commit

Permalink
Add gpu.printf lowering in GPUToSPIRV pass along with SPIRV patch to …
Browse files Browse the repository at this point in the history
…fix spirv dialect and serialization for spirv.CL.printf

Changes include:
    - gpu.printf op conversion lowering in GPUToSPIRV pass that lowers to spirv.CL.printf op
    - Adds lit test as well as e2e test case to verify the lowering through imex-convert-gpu-to-spirv pass
    - Adds a PATCH that fixes upstream MLIR spirv dialect to support SpecConstantComposite as an initializer
      for spirv.GlobalVariable op. patch "0001-SPIRV-add-SpecConstantComposite-Op-support-in-Global.patch" is
      added until spirv dialect fix is upstreamed.
  • Loading branch information
drprajap committed Dec 13, 2023
1 parent 8a87398 commit ad25087
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
From fb3fb52995e3441aa8f8abc14006eda6a599c94b Mon Sep 17 00:00:00 2001
From: "Prajapati, Dimple" <[email protected]>
Date: Fri, 1 Dec 2023 12:57:17 -0800
Subject: [PATCH] SPIRV: add SpecConstantComposite Op support in GlobalVarOp

---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 5 +++--
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp | 11 ++++++++---
2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea72..b39f1607ad7f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1163,9 +1163,10 @@ LogicalResult spirv::GlobalVariableOp::verify() {
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
if (!initOp ||
- !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
+ !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
- "spirv.SpecConstant or spirv.GlobalVariable op");
+ "spirv.SpecConstant or spirv.GlobalVariable op or "
+ "spirv.SpecConstantCompositeOp");
}
}

diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b..cc968c0627e9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -382,13 +382,18 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Encode StorageClass.
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));

+ // Encode initialization.
+
// Encode initialization.
if (auto initializer = varOp.getInitializer()) {
auto initializerID = getVariableID(*initializer);
+
if (!initializerID) {
- return emitError(varOp.getLoc(),
- "invalid usage of undefined variable as initializer");
- }
+ initializerID = getSpecConstID(*initializer);
+ if (!initializerID)
+ return emitError(varOp.getLoc(),
+ "invalid usage of undefined variable as initializer");
+ }
operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
}
--
2.34.1
9 changes: 9 additions & 0 deletions include/imex/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@

#ifndef IMEX_GPUTOSPIRV_PASS_H_
#define IMEX_GPUTOSPIRV_PASS_H_
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h>
#include <mlir/Transforms/DialectConversion.h>

#include <memory>

namespace mlir {
class SPIRVTypeConverter;
class RewritePatternSet;
class Pass;
struct ScfToSPIRVContextImpl;
class ModuleOp;
Expand All @@ -26,6 +31,10 @@ template <typename T> class OperationPass;
} // namespace mlir

namespace imex {

void populateGPUPrintfToSPIRVPatterns(mlir::SPIRVTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns);

/// Create a pass
std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>>
createConvertGPUXToSPIRVPass(bool mapMemorySpace = true);
Expand Down
122 changes: 122 additions & 0 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -59,6 +68,117 @@ class GPUXToSPIRVPass : public ::imex::ConvertGPUXToSPIRVBase<GPUXToSPIRVPass> {
bool mapMemorySpace;
};

class PrintfOpPattern : public mlir::OpConversionPattern<mlir::gpu::PrintfOp> {
public:
using mlir::OpConversionPattern<mlir::gpu::PrintfOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto loc = gpuPrintfOp.getLoc();

auto funcOp = rewriter.getBlock()
->getParent()
->getParentOfType<mlir::spirv::FuncOp>();

auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();

const char formatStringPrefix[] = "printfMsg";
unsigned stringNumber = 0;
mlir::SmallString<16> globalVarName;
mlir::spirv::GlobalVariableOp globalVar;

// formulate spirv global variable name
do {
globalVarName.clear();
(formatStringPrefix + llvm::Twine(stringNumber++))
.toStringRef(globalVarName);
} while (moduleOp.lookupSymbol(globalVarName));

auto i8Type = rewriter.getI8Type();
auto i32Type = rewriter.getI32Type();

unsigned scNum = 0;
auto createSpecConstant = [&](unsigned value) {
auto attr = rewriter.getI8IntegerAttr(value);
mlir::SmallString<16> specCstName;
(llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
.toStringRef(specCstName);

return rewriter.create<mlir::spirv::SpecConstantOp>(
loc, rewriter.getStringAttr(specCstName), attr);
};

// define GlobalVarOp with printf format string using SpecConstants
// and make composite of SpecConstants
{
mlir::Operation *parent =
mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());

mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);

mlir::Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(
&entryBlock); // insertion point at module level

// Create Constituents with SpecConstant to construct
// SpecConstantCompositeOp
llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
formatString.push_back('\0'); // Null terminate for C
mlir::SmallVector<mlir::Attribute, 4> constituents;
for (auto c : formatString) {
auto cSpecConstantOp = createSpecConstant(c);
constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
}

// Create specialization constant composite defined via spirv.SpecConstant
size_t contentSize = constituents.size();
auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
mlir::spirv::SpecConstantCompositeOp specCstComposite;
mlir::SmallString<16> specCstCompositeName;
(llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
loc, mlir::TypeAttr::get(globalType),
rewriter.getStringAttr(specCstCompositeName),
rewriter.getArrayAttr(constituents));

// Define GlobalVariable initialized from Constant Composite
globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
loc,
mlir::spirv::PointerType::get(
globalType, mlir::spirv::StorageClass::UniformConstant),
globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
}

// Get SSA value of Global variable
mlir::Value globalPtr =
rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);

mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
loc,
mlir::spirv::PointerType::get(
i8Type, mlir::spirv::StorageClass::UniformConstant),
globalPtr);

// Get printf arguments
auto argsRange = adaptor.getArgs();
mlir::SmallVector<mlir::Value, 4> printfArgs;
printfArgs.reserve(argsRange.size() + 1);
printfArgs.append(argsRange.begin(), argsRange.end());

rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

rewriter.eraseOp(gpuPrintfOp);

return mlir::success();
}
};

void populateGPUPrintfToSPIRVPatterns(mlir::SPIRVTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns) {

patterns.add<PrintfOpPattern>(typeConverter, patterns.getContext());
}

void GPUXToSPIRVPass::runOnOperation() {
mlir::MLIRContext *context = &getContext();
mlir::ModuleOp module = getOperation();
Expand Down Expand Up @@ -242,6 +362,8 @@ void GPUXToSPIRVPass::runOnOperation() {
mlir::populateSCFToSPIRVPatterns(typeConverter, scfToSpirvCtx, patterns);
mlir::cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
imex::populateGPUPrintfToSPIRVPatterns(typeConverter, patterns);

if (this->enableVCIntrinsic)
imex::populateXeGPUToVCIntrinsicsPatterns(typeConverter, patterns);
else if (this->enableJointMatrix)
Expand Down
17 changes: 17 additions & 0 deletions test/Conversion/GPUToSPIRV/gpu-to-llvm.pp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
builtin.module(
imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
spirv.module(spirv-lower-abi-attrs
spirv-update-vce)
func.func(llvm-request-c-wrappers)
serialize-spirv
convert-gpu-to-gpux
convert-scf-to-cf
convert-cf-to-llvm
convert-arith-to-llvm
convert-func-to-llvm
convert-math-to-llvm
convert-gpux-to-llvm
expand-strided-metadata
lower-affine
finalize-memref-to-llvm
reconcile-unrealized-casts)
55 changes: 55 additions & 0 deletions test/Conversion/GPUToSPIRV/printf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: imex-opt -allow-unregistered-dialect -split-input-file -imex-convert-gpu-to-spirv='enable-vc-intrinsic=true' -verify-diagnostics %s -o - | FileCheck %s

module @test attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spirv.resource_limits<>>
} {
func.func @print_test() {
%c1 = arith.constant 1 : index
%c100 = arith.constant 100: i32
%cst_f32 = arith.constant 314.4: f32

gpu.launch_func @kernel_module1::@test_printf_arg
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args(%c100: i32, %cst_f32: f32)
return
}

// CHECK-LABEL: spirv.module @{{.*}} Physical64 OpenCL
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// spirv.SpecConstantComposite
gpu.module @kernel_module0 {
gpu.func @test_printf(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = gpu.thread_id x
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]], {{.*}} : (!spirv.ptr<i8, UniformConstant>, ({{.*}})) -> i32
gpu.printf "\nHello\n"
gpu.return
}
}

// CHECK-LABEL: spirv.module @{{.*}} Physical64 OpenCL
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// spirv.SpecConstantComposite
gpu.module @kernel_module1 {
gpu.func @test_printf_arg(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = gpu.thread_id x
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]], {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, f32, i64)) -> i32
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
gpu.return
}
}
}
39 changes: 39 additions & 0 deletions test/Conversion/GPUToSPIRV/printf_with_runner.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module attributes {
gpu.container_module
}{

func.func @main() {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c100 = arith.constant 100: i32
%cst_f32 = arith.constant 314.4: f32

gpu.launch_func @kernel_module::@print_kernel
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args(%c100: i32, %cst_f32: f32)
// CHECK: Hello
// CHECK: Hello, world : 100 314.399994
// CHECK: Thread id: 0
return
}

gpu.module @kernel_module
attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, #spirv.resource_limits<>>
} {
gpu.func @print_kernel(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = gpu.thread_id x
gpu.printf "\nHello\n"
gpu.printf "\nHello, world : %d %f\n" %arg0, %arg1: i32, f32
gpu.printf "\nThread id: %d\n" %2: index
gpu.return
}
}
}
Loading

0 comments on commit ad25087

Please sign in to comment.