[CPU] Integrate i8mm patterns from upstream (iree-org#17007)
Inserts `populateLowerVectorToArmNeonPatterns()` so that SMMLA Arm NEON
instructions are generated for compatible tile sizes.

benchmark-extra: android-cpu-dt-only

---------

Co-authored-by: Diego Caballero <[email protected]>
KoolJBlack and dcaballe authored May 8, 2024
1 parent 97afbf4 commit 07d4fe6
Showing 10 changed files with 84 additions and 25 deletions.
@@ -161,6 +161,16 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
if (!hasUkernel(target)) {
if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
(out.isSignlessInteger(32) || out.isF32())) {
if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
return {
TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
TileMxNxK{4, 8, 8}, // Truncation of the above.
TileMxNxK{2, 8, 8}, // Truncation of the above.
TileMxNxK{1, 8, 8}, // Truncation of the above.
};
}

// Default.
return {
TileMxNxK{8, 8, 1}, // Aim to use SMLAL.
TileMxNxK{4, 8, 1}, // Truncation of the above.
@@ -170,6 +180,15 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
}
if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
(out.isSignlessInteger(32) || out.isF32())) {
if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
return {
TileMxNxK{4, 8, 32},
TileMxNxK{2, 8, 32},
TileMxNxK{1, 8, 32},
};
}

// Default.
return {
TileMxNxK{4, 16, 1}, // Aim to use SMLAL.
TileMxNxK{2, 32, 1}, // Truncation of the above.
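For context on why these are the "compatible tile sizes": a single AArch64 SMMLA instruction multiplies a 2x8 block of signed 8-bit values by an 8x2 block and accumulates into a 2x2 block of 32-bit integers, so the {8, 8, 8} tile above is a 4x4 grid of such blocks, and {4, 8, 8}, {2, 8, 8}, {1, 8, 8} are its M-truncations. The sketch below is a plain-C++ reference of that per-instruction arithmetic, written here from the ISA description rather than taken from this commit.

```cpp
#include <array>
#include <cstdint>

// Reference semantics of one SMMLA: acc(2x2, i32) += lhs(2x8, i8) * rhs(2x8, i8)^T.
// lhs holds two rows of the M dimension, rhs holds two columns of the N
// dimension (stored as 2x8), and K advances by 8 per instruction.
void smmlaReference(std::array<int32_t, 4> &acc,         // 2x2 accumulator
                    const std::array<int8_t, 16> &lhs,   // 2 (M) x 8 (K)
                    const std::array<int8_t, 16> &rhs) { // 2 (N) x 8 (K)
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 8; ++k)
        acc[i * 2 + j] += static_cast<int32_t>(lhs[i * 8 + k]) *
                          static_cast<int32_t>(rhs[j * 8 + k]);
}
```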
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel
@@ -122,6 +122,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:ArithTransforms",
"@llvm-project//mlir:ArmNeon2dToIntr",
"@llvm-project//mlir:ArmNeonDialect",
"@llvm-project//mlir:ArmNeonTransforms",
"@llvm-project//mlir:ArmSMEToLLVM",
"@llvm-project//mlir:ArmSMEToLLVMIRTranslation",
"@llvm-project//mlir:ArmSMEToSCF",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
@@ -96,6 +96,7 @@ iree_cc_library(
MLIRArithTransforms
MLIRArmNeon2dToIntr
MLIRArmNeonDialect
MLIRArmNeonTransforms
MLIRArmSMEToLLVM
MLIRArmSMEToLLVMIRTranslation
MLIRArmSMEToSCF
@@ -104,6 +104,7 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
pipelineOpts.enableUkernels = hasUkernel(target);
pipelineOpts.enableAArch64SSVE =
isAArch64(target) && hasAnySVEFeature(target) && hasSMEFeature(target);
pipelineOpts.enableAArch64I8mm = isAArch64(target) && hasI8mmFeature(target);

IREE::Codegen::TranslationInfoAttr translationInfo =
getTranslationInfo(funcOp);
@@ -6,6 +6,8 @@

#include "iree/compiler/Codegen/LLVMCPU/PassDetail.h"
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -21,12 +23,15 @@ class LLVMCPUVirtualVectorLoweringPass
LLVMCPUVirtualVectorLoweringPass> {
public:
using LLVMCPUVirtualVectorLoweringBase::LLVMCPUVirtualVectorLoweringBase;
-LLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo) {
+LLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo,
+                                 bool enableArmI8mm) {
this->splitVectorTransfersTo = splitVectorTransfersTo;
+this->enableArmI8mm = enableArmI8mm;
}

void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<linalg::LinalgDialect, vector::VectorDialect>();
+registry.insert<linalg::LinalgDialect, vector::VectorDialect,
+                arm_neon::ArmNeonDialect>();
}
void runOnOperation() override;
};
@@ -52,30 +57,43 @@ void LLVMCPUVirtualVectorLoweringPass::runOnOperation() {
.setVectorMultiReductionLowering(vectorMultiReductionLowering)
.setVectorTransferSplit(vectorTransferSplit);

-RewritePatternSet patterns(ctx);
-vector::populateVectorToVectorCanonicalizationPatterns(patterns);
-vector::populateVectorGatherLoweringPatterns(patterns);
-vector::populateVectorContractLoweringPatterns(
-    patterns, vectorTransformOptions,
-    /*benefit=*/1,
-    /*disableOuterProductLowering=*/false);
-// This pattern will transform vector loads whose elements are used in a
-// scalar fashion into scalar loads. This will let scalar loads to be folded
-// into broadcast/arithmetic operations and reduce register pressure.
-vector::populateScalarVectorTransferLoweringPatterns(
-    patterns, /*benefit=*/1, /*allowMultipleUses=*/true);
-vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-vector::populateVectorMultiReductionLoweringPatterns(
-    patterns, vectorMultiReductionLowering);
-populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
-(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+// Target-dependent patterns.
+{
+  if (enableArmI8mm) {
+    RewritePatternSet patterns(ctx);
+    arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+}

+// Target-independent patterns.
+{
+  RewritePatternSet patterns(ctx);
+  vector::populateVectorToVectorCanonicalizationPatterns(patterns);
+  vector::populateVectorGatherLoweringPatterns(patterns);
+  vector::populateVectorContractLoweringPatterns(
+      patterns, vectorTransformOptions,
+      /*benefit=*/1,
+      /*disableOuterProductLowering=*/false);
+  // This pattern will transform vector loads whose elements are used in a
+  // scalar fashion into scalar loads. This will let scalar loads to be folded
+  // into broadcast/arithmetic operations and reduce register pressure.
+  vector::populateScalarVectorTransferLoweringPatterns(
+      patterns, /*benefit=*/1, /*allowMultipleUses=*/true);
+  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+  vector::populateVectorMultiReductionLoweringPatterns(
+      patterns, vectorMultiReductionLowering);
+  populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
}
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo) {
+createLLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo,
+                                       bool enableArmI8mm) {
  return std::make_unique<LLVMCPUVirtualVectorLoweringPass>(
-      splitVectorTransfersTo);
+      splitVectorTransfersTo, enableArmI8mm);
}

} // namespace mlir::iree_compiler
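A minimal usage sketch of the updated factory (signature as introduced in the hunk above, call shape matching how the Passes.cpp pipelines below invoke it); `funcPassManager` is assumed to be the function-level `OpPassManager` the pipeline builders already have:

```cpp
// Thread the new flag through when building a function pass pipeline.
// "linalg-copy" is the splitVectorTransfersTo value most pipelines use;
// enableArmI8mm would normally come from the pipeline options rather than
// being hard-coded.
funcPassManager.addPass(createLLVMCPUVirtualVectorLoweringPass(
    /*splitVectorTransfersTo=*/"linalg-copy", /*enableArmI8mm=*/true));
```

With the default `enableArmI8mm = false`, the pass behaves exactly as before; the SMMLA rewrite only runs when the flag is set.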
12 changes: 9 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -93,7 +93,7 @@ static llvm::cl::opt<bool> clEnableVectorContractCustomKernels(
"iree-llvmcpu-enable-vector-contract-custom-kernels",
llvm::cl::desc("Enables vector contract custom kernels for "
"LLVMCPUMmt4dVectorLowering pass."),
-    llvm::cl::init(true));
+    llvm::cl::init(false));

static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createTileAndDistributeToWorkgroupsPass());
@@ -288,8 +288,8 @@ void buildLLVMCPUVectorLoweringPipeline(
OpPassManager &funcPassManager,
const LLVMCPUVectorLoweringPassOptions &options) {
funcPassManager.addPass(createLLVMCPUDropVectorUnitDimsPass());
-  funcPassManager.addPass(
-      createLLVMCPUVirtualVectorLoweringPass(options.splitVectorTransfersTo));
+  funcPassManager.addPass(createLLVMCPUVirtualVectorLoweringPass(
+      options.splitVectorTransfersTo, options.enableArmI8mm));

// Make sure we remove redundant vector ops (e.g., vector transposes) before we
// lower them, after which they can no longer be optimized away.
@@ -338,6 +338,7 @@ void addCPUBufferOpsTileAndVectorizePipeline(
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

@@ -414,6 +415,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

@@ -471,6 +473,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "shuffle";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

@@ -546,6 +549,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

@@ -574,6 +578,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
@@ -606,6 +611,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
LLVMCPUVectorLoweringPassOptions options;
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
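Pieced together from the hunks above rather than written fresh, the i8mm flag takes the following path from the HAL target attribute down to the lowering pass (the statements are quoted from three different functions, not standalone code):

```cpp
// 1. LLVMCPULowerExecutableTargetPass: derive the flag from the target's CPU
//    features.
pipelineOpts.enableAArch64I8mm = isAArch64(target) && hasI8mmFeature(target);

// 2. Each tile-and-vectorize pipeline builder in Passes.cpp: copy it into the
//    vector lowering options.
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;

// 3. buildLLVMCPUVectorLoweringPipeline: hand it to the pass.
funcPassManager.addPass(createLLVMCPUVirtualVectorLoweringPass(
    options.splitVectorTransfersTo, options.enableArmI8mm));
```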
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
@@ -86,13 +86,15 @@ createLLVMCPUUnfuseFMAOpsPass();
struct LLVMCPUVectorLoweringPassOptions {
std::string splitVectorTransfersTo = "";
bool lowerVectorTransposeToAVX2 = false;
bool enableArmI8mm = false;
};

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUDropVectorUnitDimsPass();

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
-createLLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo = "");
+createLLVMCPUVirtualVectorLoweringPass(std::string splitVectorTransfersTo = "",
+                                       bool enableArmI8mm = false);

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createLLVMCPUVectorTransferLoweringPass();
@@ -140,6 +142,7 @@ struct LLVMCPUPipelineOptions {
bool enablePeeling = false;
bool enableVectorMasking = false;
bool enableAArch64SSVE = false;
bool enableAArch64I8mm = false;
bool enableUkernels = false;
bool lowerToAVX2 = false;
};
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
@@ -180,6 +180,9 @@ def LLVMCPUVirtualVectorLowering :
"\tlinalg-copy: use linalg.fill + linalg.generic for the slow path\n"
"\tvector-transfers: use extra small unmasked vector.transfers for"
" the slow path\n}]>,
Option<"enableArmI8mm", "enable-arm-i8mm", "bool",
/*default=*/ "false",
"Enables arm i8mm lowering patterns">,
];
let constructor =
"mlir::iree_compiler::createLLVMCPUVirtualVectorLoweringPass()";
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.cpp
@@ -54,4 +54,8 @@ bool hasSMEFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+sme");
}

bool hasI8mmFeature(IREE::HAL::ExecutableTargetAttr targetAttr) {
return hasFeature(targetAttr, "+i8mm");
}

} // namespace mlir::iree_compiler
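`hasFeature` is defined earlier in this file and is not part of this diff; the sketch below only illustrates the kind of check `hasI8mmFeature` relies on, with the attribute lookup replaced by a plain string parameter. The helper name and the parsing details here are assumptions for illustration, not taken from this commit.

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

// Hypothetical stand-in: split a cpu_features string such as
// "+neon,+i8mm,+sve" on commas and look for an exact "+i8mm" entry.
static bool hasFeatureInString(llvm::StringRef cpuFeatures,
                               llvm::StringRef feature) {
  llvm::SmallVector<llvm::StringRef> features;
  cpuFeatures.split(features, ',');
  return llvm::is_contained(features, feature);
}
```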
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Utils.h
@@ -39,6 +39,9 @@ bool hasAnySVEFeature(IREE::HAL::ExecutableTargetAttr targetAttr);
/// Returns true if the 'targetAttr' contains '+sme' in its cpu features.
bool hasSMEFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

/// Returns true if the 'targetAttr' contains '+i8mm' in its cpu features.
bool hasI8mmFeature(IREE::HAL::ExecutableTargetAttr targetAttr);

} // namespace mlir::iree_compiler

#endif // IREE_COMPILER_CODEGEN_LLVMCPU_UTILS_H_
