Skip to content

Commit

Permalink
Use LHS components as synthesis components
Browse files Browse the repository at this point in the history
  • Loading branch information
regehr authored and rsas committed May 11, 2018
1 parent bfc9d26 commit 874065b
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 46 deletions.
62 changes: 34 additions & 28 deletions include/souper/Infer/InstSynthesis.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ typedef std::pair<unsigned, unsigned> LocVar;
typedef std::pair<LocVar, Inst *> LocInst;

/// A component is a fixed-width instruction kind
/// or created from Origin
struct Component {
Inst::Kind Kind;
unsigned Width;
std::vector<unsigned> OpWidths;
Inst *Origin;
std::vector<Inst *> OriginOps;
};

/// Unsupported components kinds
Expand All @@ -94,35 +97,35 @@ static const std::set<Inst::Kind> UnsupportedCompKinds = {
/// a component of that width is instantiated.
/// Again, note that constants are treated as ordinary inputs
static const std::vector<Component> CompLibrary = {
Component{Inst::Add, 0, {0,0}},
Component{Inst::Sub, 0, {0,0}},
Component{Inst::Mul, 0, {0,0}},
Component{Inst::UDiv, 0, {0,0}},
Component{Inst::SDiv, 0, {0,0}},
Component{Inst::UDivExact, 0, {0,0}},
Component{Inst::SDivExact, 0, {0,0}},
Component{Inst::URem, 0, {0,0}},
Component{Inst::SRem, 0, {0,0}},
Component{Inst::And, 0, {0,0}},
Component{Inst::Or, 0, {0,0}},
Component{Inst::Xor, 0, {0,0}},
Component{Inst::Shl, 0, {0,0}},
Component{Inst::LShr, 0, {0,0}},
Component{Inst::LShrExact, 0, {0,0}},
Component{Inst::AShr, 0, {0,0}},
Component{Inst::AShrExact, 0, {0,0}},
Component{Inst::Select, 0, {1,0,0}},
Component{Inst::Eq, 1, {0,0}},
Component{Inst::Ne, 1, {0,0}},
Component{Inst::Ult, 1, {0,0}},
Component{Inst::Slt, 1, {0,0}},
Component{Inst::Ule, 1, {0,0}},
Component{Inst::Sle, 1, {0,0}},
Component{Inst::Add, 0, {0,0}, 0, {}},
Component{Inst::Sub, 0, {0,0}, 0, {}},
Component{Inst::Mul, 0, {0,0}, 0, {}},
Component{Inst::UDiv, 0, {0,0}, 0, {}},
Component{Inst::SDiv, 0, {0,0}, 0, {}},
Component{Inst::UDivExact, 0, {0,0}, 0, {}},
Component{Inst::SDivExact, 0, {0,0}, 0, {}},
Component{Inst::URem, 0, {0,0}, 0, {}},
Component{Inst::SRem, 0, {0,0}, 0, {}},
Component{Inst::And, 0, {0,0}, 0, {}},
Component{Inst::Or, 0, {0,0}, 0, {}},
Component{Inst::Xor, 0, {0,0}, 0, {}},
Component{Inst::Shl, 0, {0,0}, 0, {}},
Component{Inst::LShr, 0, {0,0}, 0, {}},
Component{Inst::LShrExact, 0, {0,0}, 0, {}},
Component{Inst::AShr, 0, {0,0}, 0, {}},
Component{Inst::AShrExact, 0, {0,0}, 0, {}},
Component{Inst::Select, 0, {1,0,0}, 0, {}},
Component{Inst::Eq, 1, {0,0}, 0, {}},
Component{Inst::Ne, 1, {0,0}, 0, {}},
Component{Inst::Ult, 1, {0,0}, 0, {}},
Component{Inst::Slt, 1, {0,0}, 0, {}},
Component{Inst::Ule, 1, {0,0}, 0, {}},
Component{Inst::Sle, 1, {0,0}, 0, {}},
//
Component{Inst::CtPop, 0, {0}},
Component{Inst::BSwap, 0, {0}},
Component{Inst::Cttz, 0, {0}},
Component{Inst::Ctlz, 0, {0}}
Component{Inst::CtPop, 0, {0}, 0, {}},
Component{Inst::BSwap, 0, {0}, 0, {}},
Component{Inst::Cttz, 0, {0}, 0, {}},
Component{Inst::Ctlz, 0, {0}, 0, {}}
};

class InstSynthesis {
Expand All @@ -132,13 +135,15 @@ class InstSynthesis {
const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
Inst *TargetLHS, Inst *&RHS,
const std::vector<Inst *> &LHSComps,
InstContext &IC, unsigned Timeout);

private:
/// Local references
SMTLIBSolver *LSMTSolver;
const BlockPCs *LBPCs;
const std::vector<InstMapping> *LPCs;
const std::vector<Inst *> *LLHSComps;
InstContext *LIC;
unsigned LTimeout;

Expand Down Expand Up @@ -291,6 +296,7 @@ class InstSynthesis {

/// Helper functions
void filterFixedWidthIntrinsicComps();
Component getCompFromInst(Inst *);
void getInputVars(Inst *I, std::vector<Inst *> &InputVars);
std::string getLocVarStr(const LocVar &Loc, const std::string Prefix="");
LocVar getLocVarFromStr(const std::string &Str);
Expand Down
5 changes: 4 additions & 1 deletion lib/Extractor/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,11 @@ class BaseSolver : public Solver {
}

if (InferInsts && SMTSolver->supportsModels()) {
std::vector<Inst *> LHSComps;
findCands(LHS, LHSComps, IC, MaxNops);
InstSynthesis IS;
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout);
EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS,
LHSComps, IC, Timeout);
if (EC || RHS)
return EC;
}
Expand Down
95 changes: 78 additions & 17 deletions lib/Infer/InstSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,
const BlockPCs &BPCs,
const std::vector<InstMapping> &PCs,
Inst *TargetLHS, Inst *&RHS,
const std::vector<Inst *> &LHSComps,
InstContext &IC, unsigned Timeout) {
std::error_code EC;

// init local refs
LSMTSolver = SMTSolver;
LBPCs = &BPCs;
LPCs = &PCs;
LLHSComps = &LHSComps;
LIC = &IC;
LTimeout = Timeout;

Expand All @@ -91,7 +93,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver,

if (DebugLevel > 0) {
llvm::outs() << "; starting synthesis for LHS\n";
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context);
PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context, true);
if (DebugLevel > 2)
printInitInfo();
}
Expand Down Expand Up @@ -322,7 +324,7 @@ void InstSynthesis::setCompLibrary() {
for (auto KindStr : splitString(CmdUserCompKinds.c_str())) {
Inst::Kind K = Inst::getKind(KindStr);
if (KindStr == Inst::getKindName(Inst::Const)) // Special case
InitConstComps.push_back(Component{Inst::Const, 0, {}});
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
else if (K == Inst::ZExt || K == Inst::SExt || K == Inst::Trunc)
report_fatal_error("don't use zext/sext/trunc explicitly");
else if (K == Inst::None)
Expand All @@ -338,13 +340,13 @@ void InstSynthesis::setCompLibrary() {
InitComps.push_back(Comp);
} else {
InitComps = CompLibrary;
InitConstComps.push_back(Component{Inst::Const, 0, {}});
InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}});
}
for (auto const &In : Inputs) {
if (In->Width == DefaultWidth)
continue;
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}});
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}});
Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}, 0, {}});
Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}, 0, {}});
}
// Second, for each input/constant create a component of DefaultWidth
for (auto &Comp : InitComps) {
Expand All @@ -362,7 +364,23 @@ void InstSynthesis::setCompLibrary() {
}
// Third, create one trunc comp to match the output width if necessary
if (LHS->Width < DefaultWidth)
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}});
Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}, 0, {}});
// Finally, add LHS components (if provided) directly to Comps,
// their widths are already initialized.
for (auto I : *LLHSComps) {
// No support for the following Insts
switch (I->K) {
case Inst::Phi:
// TODO: Why do we get these as candidates?!
case Inst::Var:
case Inst::Const:
case Inst::UntypedConst:
continue;
default:
break;
}
Comps.push_back(getCompFromInst(I));
}
}

void InstSynthesis::initInputVars(InstContext &IC) {
Expand Down Expand Up @@ -438,10 +456,11 @@ void InstSynthesis::filterFixedWidthIntrinsicComps() {

void InstSynthesis::initComponents(InstContext &IC) {
for (unsigned J = 0; J < Comps.size(); ++J) {
auto const &Comp = Comps[J];
auto &Comp = Comps[J];
std::string LocVarStr;
// First, init component inputs
std::vector<Inst *> CompOps;
std::map<Inst *, Inst *> OpsReplacements;
std::vector<LocVar> OpsLocVar;
for (unsigned K = 0; K < Comp.OpWidths.size(); ++K) {
LocVar In = std::make_pair(J+1, K+1);
Expand All @@ -464,6 +483,11 @@ void InstSynthesis::initComponents(InstContext &IC) {
CompInstMap[In] = OpInst;
CompOps.push_back(OpInst);
OpsLocVar.push_back(In);
// Update OpsReplacements
if (Comp.Origin) {
assert(Comp.OriginOps.size());
OpsReplacements.insert(std::make_pair(Comp.OriginOps[K], OpInst));
}
}
// Store all input locations
CompOpLocVars.push_back(OpsLocVar);
Expand All @@ -479,13 +503,23 @@ void InstSynthesis::initComponents(InstContext &IC) {
// Third, instantiate the component (aka Inst)
assert(Comp.Width && "comp width not set");
Inst *CompInst;
if (Comp.Kind == Inst::Select) {
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
} else {
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
if (Comp.Origin) {
assert(Comp.OriginOps.size() == CompOps.size());
CompInst = getInstCopy(Comp.Origin, *LIC, OpsReplacements);
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
// Update LHS component
Comp.Origin = CompInst;
Comp.OriginOps = CompOps;
} else {
if (Comp.Kind == Inst::Select) {
Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]});
CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]});
} else {
CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps);
if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc)
CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst});
}
}
// Update CompInstMap map with concrete Inst
CompInstMap[Out] = CompInst;
Expand Down Expand Up @@ -517,12 +551,14 @@ void InstSynthesis::printInitInfo() {
llvm::outs() << "N: " << N << ", M: " << M << "\n";
llvm::outs() << "default width: " << DefaultWidth << "\n";
llvm::outs() << "output width: " << LHS->Width << "\n";
llvm::outs() << "component library: ";
llvm::outs() << "component library: " << Comps.size() << "\n";
for (auto const &Comp : Comps) {
llvm::outs() << Inst::getKindName(Comp.Kind) << " (" << Comp.Width << ", { ";
for (auto const &Width : Comp.OpWidths)
llvm::outs() << Width << " ";
llvm::outs() << "}); ";
llvm::outs() << "})\n";
if (Comp.Origin)
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context, true);
}
if (Comps.size())
llvm::outs() << "\n";
Expand Down Expand Up @@ -980,15 +1016,28 @@ Inst *InstSynthesis::createInstFromWiring(
llvm::outs() << "- creating inst " << Inst::getKindName(Comp.Kind)
<< ", width " << Comp.Width << "\n";
llvm::outs() << "before junk removal:\n";
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
Context);
if (Comp.Origin)
PrintReplacementRHS(llvm::outs(), Comp.Origin, Context);
else
PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops),
Context);
}
// Sanity checks
if (Ops.size() == 2 && Ops[0]->K == Inst::Const && Ops[1]->K == Inst::Const)
report_fatal_error("inst operands are constants!");
assert(Comp.Width == 1 || Comp.Width == DefaultWidth ||
Comp.Width == LHS->Width);
// Create instruction
// Instruction is a LHS component
if (Comp.Origin) {
assert(Comp.OriginOps.size() == Ops.size());
std::map<Inst *, Inst *> OpsReplacements;
for (unsigned J = 0; J < Ops.size(); ++J)
OpsReplacements.insert(std::make_pair(Comp.OriginOps[J], Ops[J]));
Inst *Copy = getInstCopy(Comp.Origin, *LIC, OpsReplacements);
// Update ops
Ops = Copy->Ops;
}
// Create instruction from a component
if (Comp.Kind == Inst::Select) {
Ops[0] = IC.getInst(Inst::Trunc, 1, {Ops[0]});
return createCleanInst(Comp.Kind, Comp.Width, Ops, IC);
Expand Down Expand Up @@ -1214,6 +1263,18 @@ Inst *InstSynthesis::createCleanInst(Inst::Kind Kind, unsigned Width,
return IC.getInst(Kind, Width, Ops);
}

Component InstSynthesis::getCompFromInst(Inst *I) {
std::vector<Inst *> IV;
getInputVars(I, IV);
sort(IV.begin(), IV.end());
IV.erase(unique(IV.begin(), IV.end()), IV.end());
std::vector<unsigned> OpWidths;
for (auto In : IV)
OpWidths.push_back(In->Width);

return Component{I->K, I->Width, OpWidths, I, IV};
}

void InstSynthesis::getInputVars(Inst *I, std::vector<Inst *> &InputVars) {
if (I->K == Inst::Var)
InputVars.push_back(I);
Expand Down
14 changes: 14 additions & 0 deletions test/Infer/four-adds.opt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; REQUIRES: solver, solver-model

; -souper-synthesis-comps=const is just a hack to avoid the initialization of the whole component library
; RUN: %souper-check %solver -infer-rhs -souper-infer-inst -souper-synthesis-comps=const -souper-synthesis-ignore-cost %s > %t1
; RUN: %FileCheck %s < %t1

; CHECK: result %4

%0:i32 = var
%1:i32 = add 1:i32, %0
%2:i32 = add 1:i32, %1
%3:i32 = add 1:i32, %2
%4:i32 = add 1:i32, %3
infer %4

0 comments on commit 874065b

Please sign in to comment.