Skip to content

Commit 89642ac

Browse files
author
chen.qian
committed
[Pass] fix sqrt add to or bug and getFirst/LastInst refactor
1 parent e116ff0 commit 89642ac

File tree

1 file changed

+89
-75
lines changed

1 file changed

+89
-75
lines changed

llvm/lib/Target/RISCV/RISCVLoopUnrollAndRemainder.cpp

+89-75
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
#include "llvm/Support/FileSystem.h"
8686
#include "llvm/Support/raw_ostream.h"
8787
#include "llvm/Transforms/IPO.h"
88+
#include "llvm/Transforms/InstCombine/InstCombine.h"
8889
#include "llvm/Transforms/Scalar.h"
8990
#include "llvm/Transforms/Scalar/DCE.h"
9091
#include "llvm/Transforms/Scalar/DeadStoreElimination.h"
@@ -171,21 +172,19 @@ static ICmpInst *getLastICmpInstWithPredicate(BasicBlock *BB,
171172
return lastICmp;
172173
}
173174

174-
// Helper function to get the first ICmp instruction in a basic block
175-
static ICmpInst *getFirstICmpInst(BasicBlock *BB) {
175+
template <typename T> static T *getFirstInst(BasicBlock *BB) {
176176
for (Instruction &I : *BB) {
177-
if (auto *CI = dyn_cast<ICmpInst>(&I)) {
178-
return CI;
177+
if (T *Inst = dyn_cast<T>(&I)) {
178+
return Inst;
179179
}
180180
}
181181
return nullptr;
182182
}
183183

184-
// Helper function to get the last ICmp instruction in a basic block
185-
static ICmpInst *getLastICmpInst(BasicBlock *BB) {
186-
for (auto it = BB->rbegin(); it != BB->rend(); ++it) {
187-
if (auto *icmp = dyn_cast<ICmpInst>(&*it)) {
188-
return icmp;
184+
template <typename T> static T *getLastInst(BasicBlock *BB) {
185+
for (Instruction &I : reverse(*BB)) {
186+
if (T *Inst = dyn_cast<T>(&I)) {
187+
return Inst;
189188
}
190189
}
191190
return nullptr;
@@ -239,16 +238,6 @@ static PHINode *getLastI32Phi(BasicBlock *BB) {
239238
return nullptr;
240239
}
241240

242-
// Helper function to get the last PHI node in a basic block
243-
static PHINode *getLastPhi(BasicBlock *BB) {
244-
for (auto it = BB->rbegin(); it != BB->rend(); ++it) {
245-
if (auto *Phi = dyn_cast<PHINode>(&*it)) {
246-
return Phi;
247-
}
248-
}
249-
return nullptr;
250-
}
251-
252241
// Helper function to get the first CallInst with a specific name in a basic
253242
// block
254243
static CallInst *getFirstCallInstWithName(BasicBlock *BB, StringRef Name) {
@@ -406,6 +395,38 @@ static void movePHINodesToTop(BasicBlock &BB,
406395
}
407396
}
408397

398+
static void modifyFirdAddToOr(BasicBlock *ClonedForBody) {
399+
SmallVector<BinaryOperator *> addInsts;
400+
401+
// Collect all add instructions that meet the criteria
402+
for (auto &I : *ClonedForBody) {
403+
if (auto *binOp = dyn_cast<BinaryOperator>(&I)) {
404+
if (binOp->getOpcode() == Instruction::Add && binOp->hasNoSignedWrap() &&
405+
binOp->hasNoUnsignedWrap()) {
406+
addInsts.push_back(binOp);
407+
}
408+
}
409+
}
410+
if (addInsts.empty()) {
411+
return;
412+
}
413+
// Replace each add instruction with an or disjoint instruction
414+
for (auto it = addInsts.begin(); it != std::prev(addInsts.end()); ++it) {
415+
auto *addInst = *it;
416+
// Create a new or disjoint instruction
417+
Instruction *orInst =
418+
BinaryOperator::CreateDisjoint(Instruction::Or, addInst->getOperand(0),
419+
addInst->getOperand(1), "add", addInst);
420+
421+
// Replace all uses of the add instruction
422+
addInst->replaceAllUsesWith(orInst);
423+
424+
// Delete the original add instruction
425+
addInst->eraseFromParent();
426+
orInst->setName("add");
427+
}
428+
}
429+
409430
// Helper function to update predecessors to point to a new preheader
410431
static void updatePredecessorsToPreheader(BasicBlock *ForBody,
411432
BasicBlock *ForBodyPreheader) {
@@ -1151,7 +1172,7 @@ static Value *expandForCondPreheader(
11511172
}
11521173

11531174
// Get the icmp instruction in ForCondPreheader
1154-
ICmpInst *icmpInst = getFirstICmpInst(ForCondPreheader);
1175+
ICmpInst *icmpInst = getFirstInst<ICmpInst>(ForCondPreheader);
11551176

11561177
// Ensure we found the icmp instruction
11571178
assert(icmpInst && "Failed to find icmp instruction in ForCondPreheader");
@@ -1278,7 +1299,7 @@ static void insertUnusedInstructionsBeforeIcmp(PHINode *phiI32InClonedForBody,
12781299

12791300
static void modifyClonedForBody(BasicBlock *ClonedForBody) {
12801301

1281-
ICmpInst *lastIcmpEq = getLastICmpInst(ClonedForBody);
1302+
ICmpInst *lastIcmpEq = getLastInst<ICmpInst>(ClonedForBody);
12821303
assert(lastIcmpEq &&
12831304
"Failed to find last icmp eq instruction in ClonedForBody");
12841305

@@ -1472,7 +1493,7 @@ static void modifyForCondPreheader2(BasicBlock *ClonedForBody,
14721493
}
14731494

14741495
// Find operand 1 of the icmp instruction from ClonedForBody
1475-
ICmpInst *firstIcmp = getFirstICmpInst(ClonedForBody);
1496+
ICmpInst *firstIcmp = getFirstInst<ICmpInst>(ClonedForBody);
14761497
assert(firstIcmp && "Unable to find icmp instruction in ClonedForBody");
14771498
Value *IcmpOperand1 = firstIcmp->getOperand(1);
14781499

@@ -1549,7 +1570,7 @@ static void modifyForCondPreheader2(BasicBlock *ClonedForBody,
15491570

15501571
static Value *modifyClonedForBodyPreheader(BasicBlock *ClonedForBodyPreheader,
15511572
BasicBlock *ForBody) {
1552-
ICmpInst *firstIcmp = getFirstICmpInst(ForBody);
1573+
ICmpInst *firstIcmp = getFirstInst<ICmpInst>(ForBody);
15531574
assert(firstIcmp && "Unable to find icmp instruction in ForBody");
15541575

15551576
Value *IcmpOperand1 = firstIcmp->getOperand(1);
@@ -2011,35 +2032,27 @@ static Instruction *modifyAddToOrInClonedForBody(BasicBlock *ClonedForBody) {
20112032
return orInst;
20122033
}
20132034

2014-
static void modifyAddToOr(BasicBlock *ClonedForBody) {
2015-
SmallVector<BinaryOperator *> addInsts;
2035+
static void runInstCombinePass(Function &F) {
2036+
// Create necessary analysis managers
2037+
LoopAnalysisManager LAM;
2038+
FunctionAnalysisManager FAM;
2039+
CGSCCAnalysisManager CGAM;
2040+
ModuleAnalysisManager MAM;
20162041

2017-
// Collect all add instructions that meet the criteria
2018-
for (auto &I : *ClonedForBody) {
2019-
if (auto *binOp = dyn_cast<BinaryOperator>(&I)) {
2020-
if (binOp->getOpcode() == Instruction::Add) {
2021-
addInsts.push_back(binOp);
2022-
}
2023-
}
2024-
}
2025-
if (addInsts.empty()) {
2026-
return;
2027-
}
2028-
// Replace each add instruction with an or disjoint instruction
2029-
for (auto it = addInsts.begin(); it != std::prev(addInsts.end()); ++it) {
2030-
auto *addInst = *it;
2031-
// Create a new or disjoint instruction
2032-
Instruction *orInst =
2033-
BinaryOperator::CreateDisjoint(Instruction::Or, addInst->getOperand(0),
2034-
addInst->getOperand(1), "add", addInst);
2042+
// Create pass builder
2043+
PassBuilder PB;
20352044

2036-
// Replace all uses of the add instruction
2037-
addInst->replaceAllUsesWith(orInst);
2045+
// Register analyses
2046+
PB.registerModuleAnalyses(MAM);
2047+
PB.registerCGSCCAnalyses(CGAM);
2048+
PB.registerFunctionAnalyses(FAM);
2049+
PB.registerLoopAnalyses(LAM);
2050+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
20382051

2039-
// Delete the original add instruction
2040-
addInst->eraseFromParent();
2041-
orInst->setName("add");
2042-
}
2052+
// Create function-level optimization pipeline
2053+
FunctionPassManager FPM;
2054+
FPM.addPass(InstCombinePass());
2055+
FPM.run(F, FAM);
20432056
}
20442057

20452058
static Value *unrolladdcClonedForBody(BasicBlock *ClonedForBody,
@@ -2058,7 +2071,7 @@ static Value *unrolladdcClonedForBody(BasicBlock *ClonedForBody,
20582071
assert(firstNonPHI && orInst && "Start or end instruction not found");
20592072

20602073
// Find the icmp instruction
2061-
Instruction *icmpInst = getFirstICmpInst(ClonedForBody);
2074+
Instruction *icmpInst = getFirstInst<ICmpInst>(ClonedForBody);
20622075

20632076
// Ensure that the icmp instruction is found
20642077
assert(icmpInst && "icmp instruction not found");
@@ -2298,7 +2311,7 @@ static void unrollAddc(Function &F, ScalarEvolution &SE, Loop *L,
22982311
assert(ForCondPreheader && "Expected to find for.cond.preheader!");
22992312
expandForCondPreheaderaddc(F, ForCondPreheader, ClonedForBody, ForBody, sub,
23002313
unroll_factor);
2301-
modifyAddToOr(ClonedForBody);
2314+
runInstCombinePass(F);
23022315
groupAndReorderInstructions(ClonedForBody);
23032316

23042317
// Verify the function
@@ -2816,11 +2829,11 @@ static void postUnrollLoopWithCount(Function &F, Loop *L, int unroll_count) {
28162829
insertPhiNodesForFMulAdd(LoopHeader, LoopPreheader, FMulAddCalls);
28172830

28182831
movePHINodesToTop(*LoopHeader);
2819-
modifyAddToOr(LoopHeader);
2832+
runInstCombinePass(F);
28202833
groupAndReorderInstructions(LoopHeader);
28212834

28222835
// Create for.end basic block after LoopHeader
2823-
ICmpInst *LastICmp = getLastICmpInst(LoopHeader);
2836+
ICmpInst *LastICmp = getLastInst<ICmpInst>(LoopHeader);
28242837
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
28252838
// Get the first operand of LastICmp
28262839
Value *Operand1 = LastICmp->getOperand(1);
@@ -3023,7 +3036,7 @@ static bool shouldUnrollDotprodType(Function &F, LoopInfo *LI) {
30233036
}
30243037

30253038
static std::pair<Value *, Value *> modifyEntryBB(BasicBlock &entryBB) {
3026-
ICmpInst *icmp = getLastICmpInst(&entryBB);
3039+
ICmpInst *icmp = getLastInst<ICmpInst>(&entryBB);
30273040
assert(icmp && "icmp not found");
30283041
Value *start_index = icmp->getOperand(0);
30293042
Value *end_index = icmp->getOperand(1);
@@ -3115,7 +3128,7 @@ static void postUnrollLoopWithVariable(Function &F, Loop *L, int unroll_count) {
31153128
temp->insertBefore(LoopPreheader->getTerminator());
31163129
}
31173130

3118-
ICmpInst *lastICmp = getLastICmpInst(ForBody7);
3131+
ICmpInst *lastICmp = getLastInst<ICmpInst>(ForBody7);
31193132
assert(lastICmp && "icmp not found");
31203133
lastICmp->setOperand(1, Sub);
31213134
lastICmp->setPredicate(ICmpInst::ICMP_SLT);
@@ -3552,7 +3565,7 @@ static std::tuple<Value *, GetElementPtrInst *, Value *>
35523565
modifyOuterLoop4(Loop *L, BasicBlock *ForBodyMerged,
35533566
BasicBlock *CloneForBodyPreheader) {
35543567
BasicBlock *BB = L->getHeader();
3555-
PHINode *phi = getLastPhi(BB);
3568+
PHINode *phi = getLastInst<PHINode>(BB);
35563569
// Add new instructions
35573570
IRBuilder<> Builder(BB);
35583571
Builder.SetInsertPoint(phi->getNextNode());
@@ -3596,7 +3609,7 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
35963609
movePHINodesToTop(*ForBodyMerged);
35973610

35983611
groupAndReorderInstructions(ForBodyMerged);
3599-
ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
3612+
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
36003613
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
36013614
LastICmp->setOperand(1, Sub);
36023615
swapTerminatorSuccessors(ForBodyMerged);
@@ -3653,7 +3666,8 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
36533666
AddPHI->addIncoming(Add2, NewForEnd);
36543667
Value *phifloatincomingvalue0 =
36553668
getFirstCallInstWithName(CloneForBody, "llvm.fmuladd.f32");
3656-
Value *phii32incomingvalue0 = getLastICmpInst(CloneForBody)->getOperand(0);
3669+
Value *phii32incomingvalue0 =
3670+
getLastInst<ICmpInst>(CloneForBody)->getOperand(0);
36573671
for (PHINode &Phi : CloneForBody->phis()) {
36583672
if (Phi.getType()->isIntegerTy(32)) {
36593673
Phi.setIncomingValue(0, AddPHI);
@@ -3676,7 +3690,7 @@ static void modifyInnerLoop4(Loop *L, BasicBlock *ForBodyMerged, Value *Sub,
36763690
static std::tuple<Value *, Value *, GetElementPtrInst *>
36773691
modifyOuterLoop8(Loop *L) {
36783692
BasicBlock *BB = L->getHeader();
3679-
ICmpInst *LastICmp = getLastICmpInst(BB);
3693+
ICmpInst *LastICmp = getLastInst<ICmpInst>(BB);
36803694
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
36813695
swapTerminatorSuccessors(BB);
36823696

@@ -3714,7 +3728,7 @@ static std::tuple<Value *, Value *, GetElementPtrInst *>
37143728
modifyOuterLoop16(Loop *L) {
37153729
BasicBlock *BB = L->getHeader();
37163730
BasicBlock *BBLoopPreHeader = L->getLoopPreheader();
3717-
ICmpInst *LastICmp = getLastICmpInst(BB);
3731+
ICmpInst *LastICmp = getLastInst<ICmpInst>(BB);
37183732
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
37193733
swapTerminatorSuccessors(BB);
37203734

@@ -3763,7 +3777,7 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,
37633777
movePHINodesToTop(*ForBodyMerged);
37643778

37653779
groupAndReorderInstructions(ForBodyMerged);
3766-
ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
3780+
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
37673781
LastICmp->setPredicate(ICmpInst::ICMP_ULT);
37683782
LastICmp->setOperand(1, Add60);
37693783
swapTerminatorSuccessors(ForBodyMerged);
@@ -3873,7 +3887,7 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,
38733887

38743888
Value *operand1 = unroll_count == 16
38753889
? getFirstI32Phi(OuterBB)
3876-
: getLastICmpInst(CloneForBody)->getOperand(1);
3890+
: getLastInst<ICmpInst>(CloneForBody)->getOperand(1);
38773891
// Create a new comparison instruction
38783892
ICmpInst *NewCmp =
38793893
new ICmpInst(ICmpInst::ICMP_UGT, PhiSum, operand1, "cmp182.not587");
@@ -3890,7 +3904,8 @@ static void modifyInnerLoop(Loop *L, BasicBlock *ForBodyMerged, Value *Add60,
38903904
getFirstCallInstWithName(CloneForBody, "llvm.fmuladd.f32");
38913905
for (PHINode &Phi : CloneForBody->phis()) {
38923906
if (Phi.getType()->isIntegerTy(32)) {
3893-
Phi.setIncomingValue(0, getLastICmpInst(CloneForBody)->getOperand(0));
3907+
Phi.setIncomingValue(0,
3908+
getLastInst<ICmpInst>(CloneForBody)->getOperand(0));
38943909
Phi.setIncomingBlock(0, CloneForBody);
38953910
Phi.setIncomingValue(1, PhiSum);
38963911
Phi.setIncomingBlock(1, ForEnd164);
@@ -3981,7 +3996,7 @@ static void modifyFirstCloneForBody(BasicBlock *CloneForBody,
39813996
lastAddInst = &I;
39823997
}
39833998
}
3984-
ICmpInst *LastCmpInst = getLastICmpInst(CloneForBody);
3999+
ICmpInst *LastCmpInst = getLastInst<ICmpInst>(CloneForBody);
39854000
LastCmpInst->setOperand(0, lastAddInst);
39864001
LastCmpInst->setOperand(1, Operand1);
39874002
FirstI32Phi->setIncomingValue(1, lastAddInst);
@@ -4045,7 +4060,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
40454060
getFirstI32Phi(ForCond23Preheader)->getIncomingBlock(0);
40464061
Instruction *FirstI32Phi = getFirstI32Phi(ForCondCleanup3);
40474062

4048-
ICmpInst *LastICmp = getLastICmpInst(ForCondCleanup3);
4063+
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForCondCleanup3);
40494064
// Create new add instruction
40504065
IRBuilder<> Builder(LastICmp);
40514066
Value *Add269 = Builder.CreateNSWAdd(
@@ -4067,7 +4082,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
40674082

40684083
N_069->setIncomingValue(1, Add281);
40694084

4070-
ICmpInst *LastICmpInPreheader = getLastICmpInst(ForCond23Preheader);
4085+
ICmpInst *LastICmpInPreheader = getLastInst<ICmpInst>(ForCond23Preheader);
40714086
// Create new phi node
40724087
PHINode *N_0_lcssa = PHINode::Create(Type::getInt32Ty(F.getContext()), 2,
40734088
"n.0.lcssa", LastICmpInPreheader);
@@ -4093,7 +4108,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
40934108
Value *Add11 = Builder.CreateAdd(Operand1, CoeffPosLcssa);
40944109

40954110
ForBody27LrPh->getTerminator()->setSuccessor(0, CloneForBody);
4096-
ICmpInst *LastICmpInForBodyMerged = getLastICmpInst(ForBodyMerged);
4111+
ICmpInst *LastICmpInForBodyMerged = getLastInst<ICmpInst>(ForBodyMerged);
40974112
LastICmpInForBodyMerged->setOperand(1, Operand1);
40984113
LastICmpInForBodyMerged->setOperand(0, Inc20_7);
40994114

@@ -4159,9 +4174,8 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
41594174
CI->setOperand(2, PHI);
41604175
}
41614176
movePHINodesToTop(*ForBodyMerged);
4162-
modifyAddToOr(ForBodyMerged);
4163-
4164-
ICmpInst *LastICmpForBodyMerged = getLastICmpInst(ForBodyMerged);
4177+
modifyFirdAddToOr(ForBodyMerged);
4178+
ICmpInst *LastICmpForBodyMerged = getLastInst<ICmpInst>(ForBodyMerged);
41654179
LastICmpForBodyMerged->setPredicate(ICmpInst::ICMP_SGT);
41664180
cast<Instruction>(LastICmpForBodyMerged->getOperand(0))
41674181
->setOperand(0, getFirstI32Phi(ForBodyMerged));
@@ -4256,7 +4270,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
42564270
CoeffPosLcssaPhi->addIncoming(SubResult, ForCondCleanup26LoopExit);
42574271
// eraseAllStoreInstInBB(ForCondCleanup26);
42584272

4259-
ICmpInst *LastICmpForCondCleanup26 = getLastICmpInst(ForCondCleanup26);
4273+
ICmpInst *LastICmpForCondCleanup26 = getLastInst<ICmpInst>(ForCondCleanup26);
42604274

42614275
LastICmpForCondCleanup26->setPredicate(ICmpInst::ICMP_SLT);
42624276
PHINode *FirstI32ForCondCleanup3 = getFirstI32Phi(ForCondCleanup3);
@@ -4314,7 +4328,7 @@ static void modifyFirdFirstLoop(Function &F, Loop *L, BasicBlock *ForBodyMerged,
43144328
0, ConstantInt::get(getLastI32Phi(ForCond130Preheader)->getType(), 0));
43154329
LastI32Phi130->setIncomingValue(1, AndResult);
43164330

4317-
ICmpInst *LastICmp130 = getLastICmpInst(ForCond130Preheader);
4331+
ICmpInst *LastICmp130 = getLastInst<ICmpInst>(ForCond130Preheader);
43184332
LastICmp130->setOperand(1, FirstI32ForCondCleanup3);
43194333

43204334
PHINode *LastI32PhiClone = getLastFloatPhi(CloneForBody);
@@ -4434,9 +4448,8 @@ static void modifyFirdSecondLoop(Function &F, Loop *L,
44344448
Add76310->addIncoming(Add76, ForBodyMerged);
44354449

44364450
movePHINodesToTop(*ForBodyMerged);
4437-
modifyAddToOr(ForBodyMerged);
4438-
4439-
ICmpInst *LastICmp = getLastICmpInst(ForBodyMerged);
4451+
modifyFirdAddToOr(ForBodyMerged);
4452+
ICmpInst *LastICmp = getLastInst<ICmpInst>(ForBodyMerged);
44404453
LastICmp->setPredicate(ICmpInst::ICMP_SGT);
44414454
cast<Instruction>(Add76)->moveBefore(LastICmp);
44424455
LastICmp->setOperand(0, Add76);
@@ -5043,6 +5056,7 @@ RISCVLoopUnrollAndRemainderPass::run(Function &F, FunctionAnalysisManager &AM) {
50435056
if (currentUnrollType == UnrollType::FIRD) {
50445057
addLegacyCommonOptimizationPasses(F);
50455058
}
5059+
50465060
// Verify function
50475061
if (verifyFunction(F, &errs())) {
50485062
LLVM_DEBUG(errs() << "Function verification failed\n");

0 commit comments

Comments
 (0)