From 874065bafaad2c526df13f1bb6cf748f0c365130 Mon Sep 17 00:00:00 2001 From: John Regehr Date: Mon, 7 May 2018 17:07:32 -0600 Subject: [PATCH] Use LHS components as synthesis components --- include/souper/Infer/InstSynthesis.h | 62 ++++++++++-------- lib/Extractor/Solver.cpp | 5 +- lib/Infer/InstSynthesis.cpp | 95 +++++++++++++++++++++++----- test/Infer/four-adds.opt | 14 ++++ 4 files changed, 130 insertions(+), 46 deletions(-) create mode 100644 test/Infer/four-adds.opt diff --git a/include/souper/Infer/InstSynthesis.h b/include/souper/Infer/InstSynthesis.h index 1db4100e8..e445ff90f 100644 --- a/include/souper/Infer/InstSynthesis.h +++ b/include/souper/Infer/InstSynthesis.h @@ -70,10 +70,13 @@ typedef std::pair LocVar; typedef std::pair LocInst; /// A component is a fixed-width instruction kind +/// or created from Origin struct Component { Inst::Kind Kind; unsigned Width; std::vector OpWidths; + Inst *Origin; + std::vector OriginOps; }; /// Unsupported components kinds @@ -94,35 +97,35 @@ static const std::set UnsupportedCompKinds = { /// a component of that width is instantiated. /// Again, note that constants are treated as ordinary inputs static const std::vector CompLibrary = { - Component{Inst::Add, 0, {0,0}}, - Component{Inst::Sub, 0, {0,0}}, - Component{Inst::Mul, 0, {0,0}}, - Component{Inst::UDiv, 0, {0,0}}, - Component{Inst::SDiv, 0, {0,0}}, - Component{Inst::UDivExact, 0, {0,0}}, - Component{Inst::SDivExact, 0, {0,0}}, - Component{Inst::URem, 0, {0,0}}, - Component{Inst::SRem, 0, {0,0}}, - Component{Inst::And, 0, {0,0}}, - Component{Inst::Or, 0, {0,0}}, - Component{Inst::Xor, 0, {0,0}}, - Component{Inst::Shl, 0, {0,0}}, - Component{Inst::LShr, 0, {0,0}}, - Component{Inst::LShrExact, 0, {0,0}}, - Component{Inst::AShr, 0, {0,0}}, - Component{Inst::AShrExact, 0, {0,0}}, - Component{Inst::Select, 0, {1,0,0}}, - Component{Inst::Eq, 1, {0,0}}, - Component{Inst::Ne, 1, {0,0}}, - Component{Inst::Ult, 1, {0,0}}, - Component{Inst::Slt, 1, {0,0}}, - Component{Inst::Ule, 1, {0,0}}, - Component{Inst::Sle, 1, {0,0}}, + Component{Inst::Add, 0, {0,0}, 0, {}}, + Component{Inst::Sub, 0, {0,0}, 0, {}}, + Component{Inst::Mul, 0, {0,0}, 0, {}}, + Component{Inst::UDiv, 0, {0,0}, 0, {}}, + Component{Inst::SDiv, 0, {0,0}, 0, {}}, + Component{Inst::UDivExact, 0, {0,0}, 0, {}}, + Component{Inst::SDivExact, 0, {0,0}, 0, {}}, + Component{Inst::URem, 0, {0,0}, 0, {}}, + Component{Inst::SRem, 0, {0,0}, 0, {}}, + Component{Inst::And, 0, {0,0}, 0, {}}, + Component{Inst::Or, 0, {0,0}, 0, {}}, + Component{Inst::Xor, 0, {0,0}, 0, {}}, + Component{Inst::Shl, 0, {0,0}, 0, {}}, + Component{Inst::LShr, 0, {0,0}, 0, {}}, + Component{Inst::LShrExact, 0, {0,0}, 0, {}}, + Component{Inst::AShr, 0, {0,0}, 0, {}}, + Component{Inst::AShrExact, 0, {0,0}, 0, {}}, + Component{Inst::Select, 0, {1,0,0}, 0, {}}, + Component{Inst::Eq, 1, {0,0}, 0, {}}, + Component{Inst::Ne, 1, {0,0}, 0, {}}, + Component{Inst::Ult, 1, {0,0}, 0, {}}, + Component{Inst::Slt, 1, {0,0}, 0, {}}, + Component{Inst::Ule, 1, {0,0}, 0, {}}, + Component{Inst::Sle, 1, {0,0}, 0, {}}, // - Component{Inst::CtPop, 0, {0}}, - Component{Inst::BSwap, 0, {0}}, - Component{Inst::Cttz, 0, {0}}, - Component{Inst::Ctlz, 0, {0}} + Component{Inst::CtPop, 0, {0}, 0, {}}, + Component{Inst::BSwap, 0, {0}, 0, {}}, + Component{Inst::Cttz, 0, {0}, 0, {}}, + Component{Inst::Ctlz, 0, {0}, 0, {}} }; class InstSynthesis { @@ -132,6 +135,7 @@ class InstSynthesis { const BlockPCs &BPCs, const std::vector &PCs, Inst *TargetLHS, Inst *&RHS, + const std::vector &LHSComps, InstContext &IC, unsigned Timeout); private: @@ -139,6 +143,7 @@ class InstSynthesis { SMTLIBSolver *LSMTSolver; const BlockPCs *LBPCs; const std::vector *LPCs; + const std::vector *LLHSComps; InstContext *LIC; unsigned LTimeout; @@ -291,6 +296,7 @@ class InstSynthesis { /// Helper functions void filterFixedWidthIntrinsicComps(); + Component getCompFromInst(Inst *); void getInputVars(Inst *I, std::vector &InputVars); std::string getLocVarStr(const LocVar &Loc, const std::string Prefix=""); LocVar getLocVarFromStr(const std::string &Str); diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index 3e0a7d50c..2dff53f8b 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -206,8 +206,11 @@ class BaseSolver : public Solver { } if (InferInsts && SMTSolver->supportsModels()) { + std::vector LHSComps; + findCands(LHS, LHSComps, IC, MaxNops); InstSynthesis IS; - EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout); + EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, + LHSComps, IC, Timeout); if (EC || RHS) return EC; } diff --git a/lib/Infer/InstSynthesis.cpp b/lib/Infer/InstSynthesis.cpp index a6281ecbb..42b609e57 100644 --- a/lib/Infer/InstSynthesis.cpp +++ b/lib/Infer/InstSynthesis.cpp @@ -59,6 +59,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, const std::vector &PCs, Inst *TargetLHS, Inst *&RHS, + const std::vector &LHSComps, InstContext &IC, unsigned Timeout) { std::error_code EC; @@ -66,6 +67,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, LSMTSolver = SMTSolver; LBPCs = &BPCs; LPCs = &PCs; + LLHSComps = &LHSComps; LIC = &IC; LTimeout = Timeout; @@ -91,7 +93,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, if (DebugLevel > 0) { llvm::outs() << "; starting synthesis for LHS\n"; - PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context); + PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context, true); if (DebugLevel > 2) printInitInfo(); } @@ -322,7 +324,7 @@ void InstSynthesis::setCompLibrary() { for (auto KindStr : splitString(CmdUserCompKinds.c_str())) { Inst::Kind K = Inst::getKind(KindStr); if (KindStr == Inst::getKindName(Inst::Const)) // Special case - InitConstComps.push_back(Component{Inst::Const, 0, {}}); + InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}}); else if (K == Inst::ZExt || K == Inst::SExt || K == Inst::Trunc) report_fatal_error("don't use zext/sext/trunc explicitly"); else if (K == Inst::None) @@ -338,13 +340,13 @@ void InstSynthesis::setCompLibrary() { InitComps.push_back(Comp); } else { InitComps = CompLibrary; - InitConstComps.push_back(Component{Inst::Const, 0, {}}); + InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}}); } for (auto const &In : Inputs) { if (In->Width == DefaultWidth) continue; - Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}}); - Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}}); + Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}, 0, {}}); + Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}, 0, {}}); } // Second, for each input/constant create a component of DefaultWidth for (auto &Comp : InitComps) { @@ -362,7 +364,23 @@ void InstSynthesis::setCompLibrary() { } // Third, create one trunc comp to match the output width if necessary if (LHS->Width < DefaultWidth) - Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}}); + Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}, 0, {}}); + // Finally, add LHS components (if provided) directly to Comps, + // their widths are already initialized. + for (auto I : *LLHSComps) { + // No support for the following Insts + switch (I->K) { + case Inst::Phi: + // TODO: Why do we get these as candidates?! + case Inst::Var: + case Inst::Const: + case Inst::UntypedConst: + continue; + default: + break; + } + Comps.push_back(getCompFromInst(I)); + } } void InstSynthesis::initInputVars(InstContext &IC) { @@ -438,10 +456,11 @@ void InstSynthesis::filterFixedWidthIntrinsicComps() { void InstSynthesis::initComponents(InstContext &IC) { for (unsigned J = 0; J < Comps.size(); ++J) { - auto const &Comp = Comps[J]; + auto &Comp = Comps[J]; std::string LocVarStr; // First, init component inputs std::vector CompOps; + std::map OpsReplacements; std::vector OpsLocVar; for (unsigned K = 0; K < Comp.OpWidths.size(); ++K) { LocVar In = std::make_pair(J+1, K+1); @@ -464,6 +483,11 @@ void InstSynthesis::initComponents(InstContext &IC) { CompInstMap[In] = OpInst; CompOps.push_back(OpInst); OpsLocVar.push_back(In); + // Update OpsReplacements + if (Comp.Origin) { + assert(Comp.OriginOps.size()); + OpsReplacements.insert(std::make_pair(Comp.OriginOps[K], OpInst)); + } } // Store all input locations CompOpLocVars.push_back(OpsLocVar); @@ -479,13 +503,23 @@ void InstSynthesis::initComponents(InstContext &IC) { // Third, instantiate the component (aka Inst) assert(Comp.Width && "comp width not set"); Inst *CompInst; - if (Comp.Kind == Inst::Select) { - Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]}); - CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]}); - } else { - CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps); + if (Comp.Origin) { + assert(Comp.OriginOps.size() == CompOps.size()); + CompInst = getInstCopy(Comp.Origin, *LIC, OpsReplacements); if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst}); + // Update LHS component + Comp.Origin = CompInst; + Comp.OriginOps = CompOps; + } else { + if (Comp.Kind == Inst::Select) { + Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]}); + CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]}); + } else { + CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps); + if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) + CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst}); + } } // Update CompInstMap map with concrete Inst CompInstMap[Out] = CompInst; @@ -517,12 +551,14 @@ void InstSynthesis::printInitInfo() { llvm::outs() << "N: " << N << ", M: " << M << "\n"; llvm::outs() << "default width: " << DefaultWidth << "\n"; llvm::outs() << "output width: " << LHS->Width << "\n"; - llvm::outs() << "component library: "; + llvm::outs() << "component library: " << Comps.size() << "\n"; for (auto const &Comp : Comps) { llvm::outs() << Inst::getKindName(Comp.Kind) << " (" << Comp.Width << ", { "; for (auto const &Width : Comp.OpWidths) llvm::outs() << Width << " "; - llvm::outs() << "}); "; + llvm::outs() << "})\n"; + if (Comp.Origin) + PrintReplacementRHS(llvm::outs(), Comp.Origin, Context, true); } if (Comps.size()) llvm::outs() << "\n"; @@ -980,15 +1016,28 @@ Inst *InstSynthesis::createInstFromWiring( llvm::outs() << "- creating inst " << Inst::getKindName(Comp.Kind) << ", width " << Comp.Width << "\n"; llvm::outs() << "before junk removal:\n"; - PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops), - Context); + if (Comp.Origin) + PrintReplacementRHS(llvm::outs(), Comp.Origin, Context); + else + PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops), + Context); } // Sanity checks if (Ops.size() == 2 && Ops[0]->K == Inst::Const && Ops[1]->K == Inst::Const) report_fatal_error("inst operands are constants!"); assert(Comp.Width == 1 || Comp.Width == DefaultWidth || Comp.Width == LHS->Width); - // Create instruction + // Instruction is a LHS component + if (Comp.Origin) { + assert(Comp.OriginOps.size() == Ops.size()); + std::map OpsReplacements; + for (unsigned J = 0; J < Ops.size(); ++J) + OpsReplacements.insert(std::make_pair(Comp.OriginOps[J], Ops[J])); + Inst *Copy = getInstCopy(Comp.Origin, *LIC, OpsReplacements); + // Update ops + Ops = Copy->Ops; + } + // Create instruction from a component if (Comp.Kind == Inst::Select) { Ops[0] = IC.getInst(Inst::Trunc, 1, {Ops[0]}); return createCleanInst(Comp.Kind, Comp.Width, Ops, IC); @@ -1214,6 +1263,18 @@ Inst *InstSynthesis::createCleanInst(Inst::Kind Kind, unsigned Width, return IC.getInst(Kind, Width, Ops); } +Component InstSynthesis::getCompFromInst(Inst *I) { + std::vector IV; + getInputVars(I, IV); + sort(IV.begin(), IV.end()); + IV.erase(unique(IV.begin(), IV.end()), IV.end()); + std::vector OpWidths; + for (auto In : IV) + OpWidths.push_back(In->Width); + + return Component{I->K, I->Width, OpWidths, I, IV}; +} + void InstSynthesis::getInputVars(Inst *I, std::vector &InputVars) { if (I->K == Inst::Var) InputVars.push_back(I); diff --git a/test/Infer/four-adds.opt b/test/Infer/four-adds.opt new file mode 100644 index 000000000..6bdb7b686 --- /dev/null +++ b/test/Infer/four-adds.opt @@ -0,0 +1,14 @@ +; REQUIRES: solver, solver-model + +; -souper-synthesis-comps=const is just a hack to avoid the initialization of the whole component library +; RUN: %souper-check %solver -infer-rhs -souper-infer-inst -souper-synthesis-comps=const -souper-synthesis-ignore-cost %s > %t1 +; RUN: %FileCheck %s < %t1 + +; CHECK: result %4 + +%0:i32 = var +%1:i32 = add 1:i32, %0 +%2:i32 = add 1:i32, %1 +%3:i32 = add 1:i32, %2 +%4:i32 = add 1:i32, %3 +infer %4