Skip to content

Commit

Permalink
[ArithToVC][Conversion] Add arith-to-vc pass. (#964)
Browse files Browse the repository at this point in the history
* [ArithToVC][Conversion] Add arith-to-vc pass.
Add a standalone conversion pass for select arithmetic ops to vc-intrinsics.
IGC does not support some of the arithmetic SPIR-V operations for
larger vector lengths (>16).
Therefore, arith dialect ops that lower to those SPIR-V ops
are converted to vc-intrinsics for larger vector lengths (>16).
  • Loading branch information
mshahneo authored Nov 16, 2024
1 parent aa75f70 commit 4844752
Show file tree
Hide file tree
Showing 17 changed files with 288 additions and 92 deletions.
51 changes: 51 additions & 0 deletions include/imex/Conversion/ArithToVC/ArithToVC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===- ArithToVC.h - Conversion---------------*- C++ -*-===//
//
// Copyright 2024 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file declares the conversion of select arith dialect operations into
/// Func dialect calls to vc-intrinsics functions.
///
//===----------------------------------------------------------------------===//
#ifndef IMEX_CONVERSION_ARITHTOVC_H
#define IMEX_CONVERSION_ARITHTOVC_H

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Vector/IR/VectorOps.h>

#include "imex/Utils/XeCommon.h"

namespace mlir {

// Forward declarations of the MLIR types referenced by the declarations in
// this header.
class ConversionTarget;
class LLVMTypeConverter;
class Pass;
class Operation;
class RewritePatternSet;
template <typename T> class OperationPass;

namespace gpu {
class GPUModuleOp;
} // namespace gpu

} // namespace mlir

namespace imex {
// Select the tablegen-generated declarations for the ConvertArithToVC pass.
#define GEN_PASS_DECL_CONVERTARITHTOVC
#include "imex/Conversion/Passes.h.inc"

/// Add the arith-to-vc-intrinsics conversion patterns to `patterns`.
/// `enableHighPrecisionInterimCalculation` defaults to false.
void populateArithToVCPatterns(
::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewritePatternSet &patterns,
bool enableHighPrecisionInterimCalculation = false);
/// Set up `target` with the legality rules used by the arith-to-vc
/// conversion.
void configureArithToVCConversionLegality(::mlir::ConversionTarget &target);
/// Create the arith-to-vc conversion pass, which operates on gpu.module ops.
std::unique_ptr<::mlir::OperationPass<::mlir::gpu::GPUModuleOp>>
createConvertArithToVCPass();

} // namespace imex
#endif
5 changes: 0 additions & 5 deletions include/imex/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,3 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion)
add_public_tablegen_target(IMEXConversionPassIncGen)

add_mlir_doc(Passes IMEXConversionPasses ./ -gen-pass-doc)
add_subdirectory(DistToStandard)
add_subdirectory(DropRegions)
add_subdirectory(XeTileToXeGPU)
add_subdirectory(XeGPUToVC)
add_subdirectory(MathToVC)
Empty file.
Empty file.
Empty file.
Empty file.
1 change: 1 addition & 0 deletions include/imex/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mlir/Pass/Pass.h"

#include <imex/Conversion/ArithToVC/ArithToVC.h>
#include <imex/Conversion/DistToStandard/DistToStandard.h>
#include <imex/Conversion/DropRegions/DropRegions.h>
#include <imex/Conversion/GPUToGPUX/GPUToGPUX.h>
Expand Down
31 changes: 31 additions & 0 deletions include/imex/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -457,4 +457,35 @@ def ConvertMathToVC : Pass<"convert-math-to-vc", "::mlir::gpu::GPUModuleOp"> {
let constructor = "imex::createConvertMathToVCPass()";
}


//===----------------------------------------------------------------------===//
// ArithToVC
//===----------------------------------------------------------------------===//
// high-precision-interim-calculation
def ConvertArithToVC : Pass<"convert-arith-to-vc", "::mlir::gpu::GPUModuleOp"> {
  let summary = "Generate vc-intrinsics functions for select arith dialect operations";
  let description = [{
    Convert select arith dialect operations into Func dialect calls to
    vc-intrinsics functions.
    The SPIR-V counterparts of some arith operations are either not supported
    by the VC compiler (IGC vector backend) or not performant enough, and
    need to be converted to vc-intrinsic calls.
    This pass converts these arith operations to vc-intrinsics.
  }];
  // Adjacent string literals are concatenated, so the first help-text literal
  // ends with a space to keep the two sentences separated.
  let options = [
    Option<"enableHighPrecisionInterimCalculation", "enable-high-precision-interim-calculation", "bool",
           /*default=*/"false",
           "Enables high precision (f32) interim calculation for arith operations. "
           "Any interim instruction added as part of the conversion will be high precision (f32).">
  ];

  let dependentDialects = ["::mlir::arith::ArithDialect",
                           "::mlir::vector::VectorDialect",
                           "::mlir::LLVM::LLVMDialect",
                           "::mlir::func::FuncDialect"
                          ];
  let constructor = "imex::createConvertArithToVCPass()";
}


#endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_
Empty file.
Empty file.
161 changes: 161 additions & 0 deletions lib/Conversion/ArithToVC/ArithToVC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
//===- ArithToVC.cpp - Conversion---------------*- C++ -*-===//
//
// Copyright 2024 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements conversion of the select arith dialect operations into
/// Func dialect calls to vc-intrinsics functions
///
//===----------------------------------------------------------------------===//

#include "imex/Conversion/ArithToVC/ArithToVC.h"
#include "imex/Utils/VCUtils.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"

namespace imex {
#define GEN_PASS_DEF_CONVERTARITHTOVC
#include "imex/Conversion/Passes.h.inc"
} // namespace imex

using namespace mlir;
using namespace imex;

namespace {

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

// Get the VC intrinsic name prefix for the given arith operation.
//
// Only arith::MaximumFOp is supported, for which "llvm.genx.fmax." is
// returned. Instantiating this template with any other op is rejected at
// compile time. A runtime assert would be compiled out in release builds and
// let control fall off the end of a value-returning function.
template <typename MOp> std::string getVCIntrinsicName() {
  static_assert(std::is_same_v<MOp, arith::MaximumFOp>,
                "Unsupported arith Op. Add more support!");
  return "llvm.genx.fmax.";
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Conversion pattern that rewrites an elementwise arith op into a Func dialect
// call to the corresponding vc-intrinsic.
//
// The rewrite applies only when the op result is a rank-1 vector and either
// imex::isVectorAnyINTELType() accepts the result type or the op carries
// fastmath flags other than `none`. Otherwise the match fails and the op is
// left untouched.
template <typename MOp>
struct ElementwiseArithOpPattern final : public OpConversionPattern<MOp> {
  using OpConversionPattern<MOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MOp op, typename MOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rank-1 vector results are handled.
    auto resultTy = op.getType();
    auto vectorTy = dyn_cast<VectorType>(resultTy);
    if (!vectorTy || vectorTy.getRank() != 1)
      return failure();

    const bool isAnyINTELVector = imex::isVectorAnyINTELType(resultTy);
    const bool hasFastmath =
        (op.getFastmathAttr().getValue() != arith::FastMathFlags::none);
    if (!(isAnyINTELVector || hasFastmath))
      return failure();

    // Intrinsic name = per-op prefix + suffix encoding the vector type.
    std::string intrinsicName = getVCIntrinsicName<MOp>();
    intrinsicName += encodeVectorType(rewriter, vectorTy).first;

    auto operands = adaptor.getOperands();
    auto location = op.getLoc();
    auto callOp = createFuncCall(rewriter, location, intrinsicName,
                                 {resultTy}, operands, false);
    rewriter.replaceOp(op, callOp);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

/// Add the arith-to-vc-intrinsics conversion patterns to `patterns`.
///
/// Currently a single pattern is registered, for arith::MaximumFOp. Neither
/// `typeConverter` nor `enableHighPrecisionInterimCalculation` is consulted
/// by this function.
void imex::populateArithToVCPatterns(
    ::mlir::LLVMTypeConverter &typeConverter,
    ::mlir::RewritePatternSet &patterns,
    bool enableHighPrecisionInterimCalculation) {
  using MaximumFPattern = ElementwiseArithOpPattern<arith::MaximumFOp>;
  auto *context = patterns.getContext();
  patterns.add<MaximumFPattern>(context);
}

//===----------------------------------------------------------------------===//
// Conversion Legality configuration
//===----------------------------------------------------------------------===//
/// Configure `target` for the arith-to-vc conversion.
///
/// The func and arith dialects are marked legal as a whole. arith.maximumf is
/// then made dynamically legal: it is illegal (and therefore converted) only
/// when its result is a 1D vector and either imex::isVectorAnyINTELType()
/// accepts the result type or the op has fastmath flags other than `none`.
void imex::configureArithToVCConversionLegality(
    ::mlir::ConversionTarget &target) {
  // Add legal dialects
  target.addLegalDialect<func::FuncDialect, arith::ArithDialect>();

  // The callback is stored in `target` and is invoked after this function has
  // returned, so it must not capture locals by reference. It needs no
  // captures at all.
  target.addDynamicallyLegalOp<arith::MaximumFOp>([](arith::MaximumFOp op) {
    // Non-vector and non-1D-vector results are never converted.
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getRank() != 1)
      return true;

    bool isAnyINTELVector = imex::isVectorAnyINTELType(op.getType());
    bool isFastmath =
        (op.getFastmathAttr().getValue() != arith::FastMathFlags::none);
    return !isAnyINTELVector && !isFastmath;
  });
}

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
// Pass that converts select arith dialect ops inside a gpu.module into calls
// to vc-intrinsics, using partial dialect conversion.
struct ArithToVCPass : public imex::impl::ConvertArithToVCBase<ArithToVCPass> {
  using Base::Base;

  // Construct the pass with the high-precision-interim-calculation option
  // preset to `enableHPIC`.
  ArithToVCPass(bool enableHPIC)
      : imex::impl::ConvertArithToVCBase<ArithToVCPass>() {
    this->enableHighPrecisionInterimCalculation.setValue(enableHPIC);
  }

  void runOnOperation() override {
    MLIRContext &context = getContext();
    LLVMTypeConverter typeConverter(&context);
    ConversionTarget target(context);
    RewritePatternSet patterns(&context);

    // Collect the rewrite patterns and the legality rules.
    const bool highPrecision =
        this->enableHighPrecisionInterimCalculation.getValue();
    imex::populateArithToVCPatterns(typeConverter, patterns, highPrecision);
    configureArithToVCConversionLegality(target);

    // Ops that remain legal are left untouched by the partial conversion.
    gpu::GPUModuleOp gpuModule = getOperation();
    if (failed(
            applyPartialConversion(gpuModule, target, std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

/// Create a default-constructed instance of the arith-to-vc conversion pass.
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
imex::createConvertArithToVCPass() {
  std::unique_ptr<OperationPass<gpu::GPUModuleOp>> pass =
      std::make_unique<ArithToVCPass>();
  return pass;
}
21 changes: 21 additions & 0 deletions lib/Conversion/ArithToVC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Conversion library for the arith-to-vc pass, built from ArithToVC.cpp.
add_imex_conversion_library(IMEXArithToVC
ArithToVC.cpp


ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToVC

# ArithToVC.cpp includes the generated imex/Conversion/Passes.h.inc.
# NOTE(review): presumably IMEXConversionPassIncGen is the target that
# generates it - confirm.
DEPENDS
IMEXConversionPassIncGen

#LINK_COMPONENTS

LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
# MLIRTransforms
MLIRLLVMCommonConversion

MLIRGPUDialect
MLIRPass
)
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(ArithToVC)
add_subdirectory(DistToStandard)
add_subdirectory(NDArrayToLinalg)
add_subdirectory(DropRegions)
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/MathToVC/MathToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ void imex::configureMathToVCConversionLegality(
namespace {
struct MathToVCPass : public imex::impl::ConvertMathToVCBase<MathToVCPass> {
using Base::Base;
MathToVCPass(bool emitDeallocs)
MathToVCPass(bool enableHPIC)
: imex::impl::ConvertMathToVCBase<MathToVCPass>() {
this->enableHighPrecisionInterimCalculation.setValue(emitDeallocs);
this->enableHighPrecisionInterimCalculation.setValue(enableHPIC);
}
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/XeGPUToVC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_imex_conversion_library(IMEXXeGPUToVC
#LINK_COMPONENTS

LINK_LIBS PUBLIC
IMEXArithToVC
IMEXMathToVC
MLIRIR
MLIRSupport
Expand Down
Loading

0 comments on commit 4844752

Please sign in to comment.