Skip to content

Commit

Permalink
Use arith/math ops instead of spirv.CL.* ops in tests. (#702)
Browse files Browse the repository at this point in the history
charithaintc authored Mar 25, 2024
1 parent b081001 commit 385bcd2
Showing 6 changed files with 47 additions and 26 deletions.
5 changes: 4 additions & 1 deletion include/imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
#ifndef IMEX_CONVERSION_XEGPUTOSPIRV_H
#define IMEX_CONVERSION_XEGPUTOSPIRV_H

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include <mlir/Dialect/SPIRV/IR/SPIRVDialect.h>
#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h>
#include <mlir/Transforms/DialectConversion.h>
@@ -25,7 +26,9 @@ class Pass;

namespace imex {

template <typename SPIRVOp> std::string getVCIntrinsicName(SPIRVOp op);
// helper to check the legal vector lengths for arith/math ops
bool isGenericVectorTy(mlir::Type type);

// XeGPU to VC Intrinsics pattern
void populateXeGPUToVCIntrinsicsPatterns(
mlir::SPIRVTypeConverter &typeConverter, mlir::RewritePatternSet &patterns);
17 changes: 17 additions & 0 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
Original file line number Diff line number Diff line change
@@ -211,6 +211,8 @@ void GPUXToSPIRVPass::runOnOperation() {
mlir::RewritePatternSet patterns(context);
mlir::SPIRVConversionOptions options;
options.use64bitIndex = true;
// FIXME: enable fast math on a per-operator basis.
options.enableFastMathMode = true;

mlir::SPIRVTypeConverter typeConverter(targetAttr, options);

@@ -360,9 +362,24 @@ void GPUXToSPIRVPass::runOnOperation() {
});
}

// SPIR-V elementwise arith/math ops require special handling if they operate
// on large vectors. We dynamically legalize these ops based on the vector
// size they consume.
// FIXME: this is not an exhaustive list of arith/math ops that need special
// handling.
target->addDynamicallyLegalOp<mlir::spirv::CLExpOp>(
[&](mlir::spirv::CLExpOp op) {
return imex::isGenericVectorTy(op.getType());
});
target->addDynamicallyLegalOp<mlir::spirv::CLFMaxOp>(
[&](mlir::spirv::CLFMaxOp op) {
return imex::isGenericVectorTy(op.getType());
});

//------- Upstream Conversion------------
mlir::populateGPUToSPIRVPatterns(typeConverter, patterns);
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
mlir::populateMathToSPIRVPatterns(typeConverter, patterns);
mlir::populateMemRefToSPIRVPatterns(typeConverter, patterns);
mlir::populateFuncToSPIRVPatterns(typeConverter, patterns);
// ---------------------------------------
30 changes: 15 additions & 15 deletions lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp
Original file line number Diff line number Diff line change
@@ -363,7 +363,7 @@ class UpdateNDOffsetToVCPattern : public OpConversionPattern<UpdateNDOffsetOp> {
auto desc = adaptor.getTensorDesc();
for (size_t i = 0; i < offsets.size(); i++) {
auto offset = offsets[i];
if (auto cst = dyn_cast<spirv::ConstantOp>(offset.getDefiningOp()))
if (auto cst = offset.getDefiningOp<arith::ConstantOp>())
if (auto attr = dyn_cast<mlir::IntegerAttr>(cst.getValue());
attr && attr.getInt() == 0)
continue;
@@ -1538,21 +1538,13 @@ struct SPIRVElementwiseToVC : public OpConversionPattern<SPIRVOp> {
if (!dstType)
return failure();

// This lowering pattern is needed only for spirv ops with large vector
// lengths.
assert(
!imex::isGenericVectorTy(dstType) &&
"Vector size is considered generic and op does not require lowering to "
"VC intrinsic. Consider marking this op + vector length as legal.");
auto vecSize = dstType.template dyn_cast<VectorType>().getNumElements();
auto hasGenericVecSize = [&]() -> bool {
// if the input is scalar, we keep the operation as is.
if (isa<spirv::ScalarType>(dstType))
return true;
// or, if the vector size is 2, 3, 4, 8, or 16, we keep the operation.
return vecSize == 2 || vecSize == 3 || vecSize == 4 || vecSize == 8 ||
vecSize == 16;
};

if (hasGenericVecSize()) {
rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
return success();
}

// for larger vector lengths, "llvm.genx.exp" returns the base 2
// exponentiation of the input. To get the base e exponentiation, we need to
// scale the input by log2(e)
@@ -1588,6 +1580,14 @@ struct SPIRVElementwiseToVC : public OpConversionPattern<SPIRVOp> {
};
} // namespace

/// Returns true when `type` is legal for direct SPIR-V arith/math lowering:
/// either a SPIR-V scalar type, or a vector whose element count is one of the
/// sizes SPIR-V supports natively (2, 3, 4, 8, or 16). Larger vectors (and any
/// other type kind) return false and require lowering to VC intrinsics.
bool imex::isGenericVectorTy(mlir::Type type) {
  // Scalars are always handled by the generic SPIR-V ops.
  if (isa<spirv::ScalarType>(type))
    return true;
  // Guard the cast: `type` may be neither a scalar nor a vector (e.g. a
  // memref); the previous unconditional dyn_cast + member call would
  // dereference a null handle in that case (undefined behavior).
  auto vecTy = type.dyn_cast<VectorType>();
  if (!vecTy)
    return false;
  auto vecSize = vecTy.getNumElements();
  // SPIR-V only defines vector types of these lengths.
  return vecSize == 2 || vecSize == 3 || vecSize == 4 || vecSize == 8 ||
         vecSize == 16;
}

void imex::populateXeGPUToVCIntrinsicsPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<CreateNdDescToSPIRV, CreateDescToVCPattern, DpasToVCPattern,
18 changes: 9 additions & 9 deletions test/Integration/Dialect/XeGPU/exp_f32.vc.mlir
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ module @gemm attributes {gpu.container_module} {
// do DPAS
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
// take exp
%t6 = spirv.CL.exp %val4 : vector<8x16xf32>
%t6 = math.exp %val4 : vector<8x16xf32>
// store
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %t6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -65,14 +65,14 @@ module @gemm attributes {gpu.container_module} {
%v6 = vector.extract %val4[6] : vector<16xf32> from vector<8x16xf32>
%v7 = vector.extract %val4[7] : vector<16xf32> from vector<8x16xf32>
// do generic size exp
%v0_exp = spirv.CL.exp %v0 : vector<16xf32>
%v1_exp = spirv.CL.exp %v1 : vector<16xf32>
%v2_exp = spirv.CL.exp %v2 : vector<16xf32>
%v3_exp = spirv.CL.exp %v3 : vector<16xf32>
%v4_exp = spirv.CL.exp %v4 : vector<16xf32>
%v5_exp = spirv.CL.exp %v5 : vector<16xf32>
%v6_exp = spirv.CL.exp %v6 : vector<16xf32>
%v7_exp = spirv.CL.exp %v7 : vector<16xf32>
%v0_exp = math.exp %v0 : vector<16xf32>
%v1_exp = math.exp %v1 : vector<16xf32>
%v2_exp = math.exp %v2 : vector<16xf32>
%v3_exp = math.exp %v3 : vector<16xf32>
%v4_exp = math.exp %v4 : vector<16xf32>
%v5_exp = math.exp %v5 : vector<16xf32>
%v6_exp = math.exp %v6 : vector<16xf32>
%v7_exp = math.exp %v7 : vector<16xf32>
%v0_exp_cast = vector.shape_cast %v0_exp : vector<16xf32> to vector<1x16xf32>
%v1_exp_cast = vector.shape_cast %v1_exp : vector<16xf32> to vector<1x16xf32>
%v2_exp_cast = vector.shape_cast %v2_exp : vector<16xf32> to vector<1x16xf32>
2 changes: 1 addition & 1 deletion test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ module @gemm attributes {gpu.container_module} {
%val4 = xegpu.dpas %val0, %val2 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%val5 = xegpu.dpas %val1, %val3 : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
// take fmax
%val6 = spirv.CL.fmax %val4, %val5 : vector<8x16xf32>
%val6 = arith.maximumf %val4, %val5 fastmath<nnan> : vector<8x16xf32>
// store fmax
%out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] { mode = vc } : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %val6, %out_tile { mode = vc} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
1 change: 1 addition & 0 deletions test/Integration/Dialect/XeGPU/xegpu-to-llvm.pp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// linalg dialect to gpu dialect lowering pipeline
// Ready for vulkan runner or narrow scope l0/sycl runner starting from GPU dialect.
builtin.module(
imex-vector-linearize
imex-convert-gpu-to-spirv{enable-vc-intrinsic=true}
spirv.module(spirv-lower-abi-attrs
spirv-update-vce)

0 comments on commit 385bcd2

Please sign in to comment.