Skip to content

Commit

Permalink
[aievec] Use intrinsics conversion to LLVM Dialect for matmul (Xilinx…
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain authored Feb 23, 2024
1 parent 54b8bc1 commit 58c0dc8
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 170 deletions.
12 changes: 12 additions & 0 deletions include/aie/Dialect/AIEVec/IR/AIEVecLLVMIntrOp.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class ExtIntrOpBase<Dialect dialect,
list<int> llvmArgIndices = [];
}

// TODO: Create an aievecllvm dialect so it can be marked legal all at once.
// TODO: That will require moving AIEVecLLVMIntrOp.td _out_ of AIEVecOps.td
// TODO: header, which is how these are being generated.

// For AIE2 only
class AIEVec2_IntrOp<string mnemonic,
list<Trait> traits = [],
Expand All @@ -46,6 +50,14 @@ def MacConfAcc32IntrOp :
VectorOfLengthAndType<[16], [I64]>:$acc,
I32:$conf)>;

def MacConfAcc64IntrOp :
AIEVec2_IntrOp<"I512.I512.ACC1024.acc64.mac.conf",
[TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
VectorOfLengthAndType<[16], [I32]>:$rhs,
VectorOfLengthAndType<[16], [I64]>:$acc,
I32:$conf)>;

def MacConfBF16IntrOp :
AIEVec2_IntrOp<"bf.mac16.conf",
[TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
Expand Down
268 changes: 125 additions & 143 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// (c) Copyright 2022 Xilinx Inc.
// (c) Copyright 2024 Advanced Micro Devices Inc.
//
//===----------------------------------------------------------------------===//

Expand All @@ -13,19 +14,100 @@
#include "aie/Conversion/AIEVecToLLVM/AIEVecToLLVM.h"
#include "aie/Dialect/AIEVec/AIEVecUtils.h"
#include "aie/Dialect/AIEVec/IR/AIEVecOps.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

#include <numeric>
#include <sstream>

using namespace mlir;

namespace xilinx::aievec {

inline static Value bitcastValueToType(OpBuilder &builder, Location loc,
Value val, Type dstTy) {
return builder.create<LLVM::BitcastOp>(loc, dstTy, val).getResult();
}

// This function emits the instructions required to widen a 128b input vector
// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
// vector<4xi32> to respect the intrinsic signature.
inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc,
Value val) {
return builder
.create<aievec::VectorSetI512I128IntrOp>(
loc, VectorType::get({16}, builder.getI32Type()),
bitcastValueToType(builder, loc, val,
VectorType::get({4}, builder.getI32Type())))
.getResult();
}

// This function emits the instructions required to widen a 256b input vector
// into a 512b encoded as a vector<16xi32>. It first bitcasts it to a
// vector<8xi32> to respect the intrinsic signature. It will also materialize
// a constant 0, used as an insertion index.
inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc,
Value val) {
auto cst0 =
builder.create<LLVM::ConstantOp>(loc, builder.getI32Type(), (int32_t)0);
return builder
.create<aievec::VectorSetI512I256IntrOp>(
loc, VectorType::get({16}, builder.getI32Type()),
bitcastValueToType(builder, loc, val,
VectorType::get({8}, builder.getI32Type())),
cst0)
.getResult();
}

// This function emits the sequence of operations that forces a value into a
// specific type. This may include widening vectors to match a specific bit
// length.
static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
Type type) {
auto valTy = val.getType();
if (valTy == type)
return val;
auto srcVecTy = dyn_cast<VectorType>(valTy);
if (srcVecTy) {
auto dstVecTy = dyn_cast<VectorType>(type);
assert(dstVecTy && "vector values cannot be forced into a non-vector type");
assert(srcVecTy.getRank() == 1 && dstVecTy.getRank() == 1 &&
"only flat 1D vectors can be force casted");
int64_t dstVecLength =
dstVecTy.getElementTypeBitWidth() * dstVecTy.getShape()[0];
int64_t srcVecLength =
srcVecTy.getElementTypeBitWidth() * srcVecTy.getShape()[0];
if (srcVecLength != dstVecLength) {
assert(srcVecLength < dstVecLength &&
"only widening forced casts are supported");
assert(dstVecLength == 512 &&
(srcVecLength == 128 || srcVecLength == 256) &&
"only 128b to 512b and 256b to 512b forced casts are supported");
if (srcVecLength == 128)
val = widen128bVectorValueTo512b(builder, loc, val);
else
val = widen256bVectorValueTo512b(builder, loc, val);
}
}
return bitcastValueToType(builder, loc, val, type);
}

// This function emits the sequence of operations that forces a range of values
// to match the signature specified by the TypeRange. It can be used to convert
// the parameters of an op being converted to the types accepted by an
// intrinsic with a fixed signature that treats its inputs as "bags of bits".
static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder,
Location loc,
ValueRange operands,
TypeRange signature) {
return llvm::to_vector(llvm::map_range(
llvm::zip_equal(operands, signature), [&](auto &&vt) -> Value {
return forceCastValueToType(builder, loc, std::get<0>(vt),
std::get<1>(vt));
}));
}

struct BufferParams {
uint32_t start;
uint32_t offsets;
Expand Down Expand Up @@ -825,82 +907,6 @@ class MatMulOpConversion
vecTy.getElementType());
}

static LLVM::LLVMFuncOp getVectorSetFunction(PatternRewriter &b,
ModuleOp moduleOp,
int64_t fromBitWidth,
int64_t toBitWidth,
bool isAcc = false) {
assert((isAcc && ((fromBitWidth == 256 || fromBitWidth == 512) &&
toBitWidth == 1024) ||
!isAcc && ((fromBitWidth == 128 || fromBitWidth == 256) &&
toBitWidth == 512)) &&
"invalid vector set function");
std::stringstream intrNameStream;
intrNameStream << "llvm.aie2.set.";
if (!isAcc)
intrNameStream << "I";
intrNameStream << toBitWidth << ".";
if (!isAcc)
intrNameStream << "I";
intrNameStream << fromBitWidth;
if (isAcc)
intrNameStream << ".acc";
std::string intrinsicName = intrNameStream.str();
MLIRContext *ctx = b.getContext();
auto funcOp = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(
StringAttr::get(ctx, intrinsicName));
if (!funcOp) {
int64_t elemBitWidth = 32;
Type elemTy = b.getI32Type();
if (isAcc) {
elemBitWidth = 64;
elemTy = b.getI64Type();
}
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallVector<Type, 2> funcSigParamTy(
{VectorType::get({fromBitWidth / elemBitWidth}, elemTy)});
if (fromBitWidth != 128)
funcSigParamTy.push_back(b.getI32Type());
funcOp = b.create<LLVM::LLVMFuncOp>(
b.getUnknownLoc(), intrinsicName,
LLVM::LLVMFunctionType::get(
VectorType::get({toBitWidth / elemBitWidth}, elemTy),
funcSigParamTy));
}
return funcOp;
}

static Value widenVectorTo512bit(PatternRewriter &b, Location loc,
Value val) {
auto valTy = cast<VectorType>(val.getType());
auto elemBitWidth = valTy.getElementTypeBitWidth();
int64_t vecBitWidth = elemBitWidth * valTy.getShape()[0];
if (vecBitWidth == 512)
return val;
if (vecBitWidth == 128)
val = b.create<LLVM::BitcastOp>(loc, VectorType::get({4}, b.getI32Type()),
val)
.getResult();
else if (vecBitWidth == 256)
val = b.create<LLVM::BitcastOp>(loc, VectorType::get({8}, b.getI32Type()),
val)
.getResult();
else
llvm_unreachable("invalid vector type");

auto moduleOp = val.getParentRegion()->getParentOfType<ModuleOp>();
auto funcOp = getVectorSetFunction(b, moduleOp, vecBitWidth, 512);
SmallVector<Value, 2> operands({val});
if (vecBitWidth == 256) {
auto zeroi32 = b.create<arith::ConstantOp>(loc, b.getI32Type(),
b.getI32IntegerAttr(0))
.getResult();
operands.push_back(zeroi32);
}
return b.create<LLVM::CallOp>(loc, funcOp, operands).getResult();
}

LogicalResult
matchAndRewrite(aievec::MatMulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Expand All @@ -921,76 +927,48 @@ class MatMulOpConversion
decodedMatMulOp.acc = rewriter.create<vector::ShapeCastOp>(
loc, accFlattenedVecTy, decodedMatMulOp.acc);

if (decodedMatMulOp.kind != DecodedMatMulOp::Kind::BF16) {
if (lhsFlattenedVecTy.getShape()[0] *
lhsFlattenedVecTy.getElementTypeBitWidth() !=
512)
decodedMatMulOp.lhs =
widenVectorTo512bit(rewriter, loc, decodedMatMulOp.lhs);
decodedMatMulOp.lhs =
rewriter
.create<LLVM::BitcastOp>(
loc, VectorType::get({64}, rewriter.getI8Type()),
decodedMatMulOp.lhs)
.getResult();
if (rhsFlattenedVecTy.getShape()[0] *
rhsFlattenedVecTy.getElementTypeBitWidth() !=
512)
decodedMatMulOp.rhs =
widenVectorTo512bit(rewriter, loc, decodedMatMulOp.rhs);
decodedMatMulOp.rhs =
rewriter
.create<LLVM::BitcastOp>(
loc, VectorType::get({16}, rewriter.getI32Type()),
decodedMatMulOp.rhs)
.getResult();
decodedMatMulOp.acc =
rewriter
.create<LLVM::BitcastOp>(
loc, VectorType::get({16}, rewriter.getI64Type()),
decodedMatMulOp.acc)
.getResult();
} else
decodedMatMulOp.acc =
Type i32ty = rewriter.getI32Type();
auto confCst = rewriter.create<LLVM::ConstantOp>(
loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
SmallVector<Value> operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs,
decodedMatMulOp.acc, confCst});
Value matMulResVal;
if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16)
matMulResVal =
rewriter
.create<LLVM::BitcastOp>(
.create<aievec::MacConfBF16IntrOp>(
loc, VectorType::get({8}, rewriter.getI64Type()),
decodedMatMulOp.acc)
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getBF16Type()),
VectorType::get({32}, rewriter.getBF16Type()),
VectorType::get({8}, rewriter.getI64Type()), i32ty}))
.getResult();
std::string intrinsicName;
if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
intrinsicName = "llvm.aie2.I512.I512.ACC1024.acc32.mac.conf";
else if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I64)
intrinsicName = "llvm.aie2.I512.I512.ACC1024.acc64.mac.conf";
else
intrinsicName = "llvm.aie2.bf.mac16.conf";

// If the intrinsic declaration doesn't exist, create it
auto module = op->getParentOfType<ModuleOp>();
MLIRContext *context = rewriter.getContext();
auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(
StringAttr::get(context, intrinsicName));

auto i32ty = rewriter.getI32Type();
if (!func) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
func = rewriter.create<LLVM::LLVMFuncOp>(
rewriter.getUnknownLoc(), intrinsicName,
LLVM::LLVMFunctionType::get(decodedMatMulOp.acc.getType(),
{decodedMatMulOp.lhs.getType(),
decodedMatMulOp.rhs.getType(),
decodedMatMulOp.acc.getType(), i32ty}));
else {
SmallVector<Type> intrFuncSig(
{VectorType::get({64}, rewriter.getI8Type()),
VectorType::get({16}, i32ty),
VectorType::get({16}, rewriter.getI64Type()), i32ty});
VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type());
if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32)
matMulResVal = rewriter
.create<aievec::MacConfAcc32IntrOp>(
loc, v16xi64ty,
forceCastOperandsToSignature(
rewriter, loc, operands, intrFuncSig))
.getResult();
else
matMulResVal = rewriter
.create<aievec::MacConfAcc64IntrOp>(
loc, v16xi64ty,
forceCastOperandsToSignature(
rewriter, loc, operands, intrFuncSig))
.getResult();
}

auto confCst = rewriter.create<arith::ConstantOp>(
loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf));
auto callOp = rewriter.create<LLVM::CallOp>(
loc, func,
ValueRange{decodedMatMulOp.lhs, decodedMatMulOp.rhs,
decodedMatMulOp.acc, confCst});
auto castFromAcc = rewriter.create<LLVM::BitcastOp>(loc, accFlattenedVecTy,
callOp.getResult());
auto castFromAcc =
bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy);

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
castFromAcc);

Expand Down Expand Up @@ -1035,6 +1013,10 @@ struct ConvertAIEVecToLLVMPass
LLVMConversionTarget target(getContext());
target.addIllegalDialect<AIEVecDialect>();
target.addLegalDialect<arith::ArithDialect, vector::VectorDialect>();
target
.addLegalOp<aievec::MacConfAcc32IntrOp, aievec::MacConfAcc64IntrOp,
aievec::MacConfBF16IntrOp, aievec::VectorSetI512I128IntrOp,
aievec::VectorSetI512I256IntrOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
Expand Down
Loading

0 comments on commit 58c0dc8

Please sign in to comment.