Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gurobi ILP solver for protocol selection #321

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
12 changes: 12 additions & 0 deletions cli/src/main/kotlin/io/github/apl_cornell/viaduct/cli/Compile.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package io.github.apl_cornell.viaduct.cli

import com.github.ajalt.clikt.core.CliktCommand
import com.github.ajalt.clikt.parameters.options.default
import com.github.ajalt.clikt.parameters.options.flag
import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.types.choice
import com.github.ajalt.clikt.parameters.types.file
import guru.nidi.graphviz.engine.Format
import guru.nidi.graphviz.engine.Graphviz
import io.github.apl_cornell.viaduct.backends.CodeGenerationBackend
import io.github.apl_cornell.viaduct.backends.DefaultCombinedBackend
import io.github.apl_cornell.viaduct.passes.compile
import io.github.apl_cornell.viaduct.passes.compileToKotlin
import io.github.apl_cornell.viaduct.selection.SelectionProblemSolver
import io.github.apl_cornell.viaduct.selection.SimpleCostRegime
import io.github.apl_cornell.viaduct.selection.defaultSelectionProblemSolver
import io.github.apl_cornell.viaduct.selection.selectionProblemSolvers
import mu.KotlinLogging
import java.io.File
import java.io.StringWriter
Expand Down Expand Up @@ -76,6 +81,11 @@ class Compile : CliktCommand(help = "Compile ideal protocol to secure distribute
help = "Translate .via source file to a .kt file"
).flag(default = false)

val selectionProblemSolver: SelectionProblemSolver by option(
"--solver",
help = "Pick which solver to use for protocol selection"
).choice(selectionProblemSolvers.toMap()).default(defaultSelectionProblemSolver)

override fun run() {
val costRegime = if (wanCost) SimpleCostRegime.WAN else SimpleCostRegime.LAN

Expand All @@ -85,6 +95,7 @@ class Compile : CliktCommand(help = "Compile ideal protocol to secure distribute
fileName = output?.nameWithoutExtension ?: "Source",
packageName = ".",
backend = CodeGenerationBackend,
selectionSolver = selectionProblemSolver,
costRegime = costRegime,
saveLabelConstraintGraph = constraintGraphOutput::dumpGraph,
saveInferredLabels = labelOutput,
Expand All @@ -96,6 +107,7 @@ class Compile : CliktCommand(help = "Compile ideal protocol to secure distribute
val compiledProgram =
input.sourceFile().compile(
backend = DefaultCombinedBackend,
selectionSolver = selectionProblemSolver,
costRegime = costRegime,
saveLabelConstraintGraph = constraintGraphOutput::dumpGraph,
saveInferredLabels = labelOutput,
Expand Down
14 changes: 13 additions & 1 deletion compiler/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ plugins {
id("org.xbib.gradle.plugin.jflex") version "1.6.0"
}

val gurobiHome: String? = System.getenv("GUROBI_HOME")
val linkGurobi: Boolean = gurobiHome != null

/** Dependencies */

dependencies {
Expand Down Expand Up @@ -39,6 +42,10 @@ dependencies {
// SMT solving
implementation("io.github.tudo-aqua:z3-turnkey:4.8.14")

if (linkGurobi) {
implementation(files("$gurobiHome/lib/gurobi.jar", "$gurobiHome/lib/gurobi-javadoc.jar"))
}

// Testing
testImplementation(project(":test-utilities"))
testImplementation(kotlin("reflect"))
Expand All @@ -54,7 +61,12 @@ val compileCup by tasks.registering(CompileCupTask::class)

sourceSets {
main {
java.srcDir(compileCup.map { it.outputDirectory })
java {
srcDir(compileCup.map { it.outputDirectory })
if (linkGurobi) {
srcDir("src/plugins/gurobi")
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import io.github.apl_cornell.viaduct.codegeneration.compileToKotlin
import io.github.apl_cornell.viaduct.parsing.SourceFile
import io.github.apl_cornell.viaduct.parsing.parse
import io.github.apl_cornell.viaduct.selection.ProtocolSelection
import io.github.apl_cornell.viaduct.selection.SelectionProblemSolver
import io.github.apl_cornell.viaduct.selection.SimpleCostEstimator
import io.github.apl_cornell.viaduct.selection.SimpleCostRegime
import io.github.apl_cornell.viaduct.selection.Z3Selection
import io.github.apl_cornell.viaduct.selection.defaultSelectionProblemSolver
import io.github.apl_cornell.viaduct.selection.validateProtocolAssignment
import io.github.apl_cornell.viaduct.syntax.intermediate.DeclarationNode
import io.github.apl_cornell.viaduct.syntax.intermediate.LetNode
Expand All @@ -34,6 +35,7 @@ private val logger = KotlinLogging.logger("Compile")
/** Similar to [SourceFile.compileToKotlin], but returns a program for the interpreter. */
fun SourceFile.compile(
backend: Backend,
selectionSolver: SelectionProblemSolver = defaultSelectionProblemSolver,
costRegime: SimpleCostRegime = SimpleCostRegime.WAN,
saveLabelConstraintGraph: ((graphWriter: (Writer) -> Unit) -> Unit)? = null,
saveInferredLabels: File? = null,
Expand Down Expand Up @@ -76,7 +78,7 @@ fun SourceFile.compile(
val costEstimator = SimpleCostEstimator(protocolComposer, costRegime)
val protocolAssignment = logger.duration("protocol selection") {
ProtocolSelection(
Z3Selection(),
selectionSolver,
protocolFactory,
protocolComposer,
costEstimator
Expand Down Expand Up @@ -136,6 +138,7 @@ fun SourceFile.compileToKotlin(
fileName: String,
packageName: String,
backend: Backend,
selectionSolver: SelectionProblemSolver = defaultSelectionProblemSolver,
costRegime: SimpleCostRegime = SimpleCostRegime.WAN,
saveLabelConstraintGraph: ((graphWriter: (Writer) -> Unit) -> Unit)? = null,
saveInferredLabels: File? = null,
Expand All @@ -145,6 +148,7 @@ fun SourceFile.compileToKotlin(
val postProcessedProgram =
this.compile(
backend,
selectionSolver,
costRegime,
saveLabelConstraintGraph,
saveInferredLabels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ class ProtocolSelection(
throw NoHostDeclarationsError(program.sourceLocation.sourcePath)
}

val constraintGenerator = SelectionConstraintGenerator(program, protocolFactory, protocolComposer, costEstimator)
val constraintGenerator =
SelectionConstraintGenerator(program, protocolFactory, protocolComposer, costEstimator)
val selectionProblem = constraintGenerator.getSelectionProblem()
return solver.solveSelectionProblem(selectionProblem) ?: throw NoSelectionSolutionError(program)
return solver.solve(selectionProblem) ?: throw NoSelectionSolutionError(program)
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package io.github.apl_cornell.viaduct.selection

interface SelectionProblemSolver {
fun solveSelectionProblem(problem: SelectionProblem): ProtocolAssignment?
fun solve(problem: SelectionProblem): ProtocolAssignment?
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.github.apl_cornell.viaduct.selection

private val gurobiSolver: SelectionProblemSolver? =
try {
val gurobiClass = Class.forName("${PackageName.javaClass.packageName}.GurobiSelectionProblemSolver")
val instance = gurobiClass.getDeclaredField("INSTANCE").get(null) as SelectionProblemSolver
instance
} catch (e: ClassNotFoundException) {
null
}

val defaultSelectionProblemSolver: SelectionProblemSolver =
gurobiSolver ?: Z3SelectionProblemSolver

/** Returns a list name-constructor pairs for all [SelectionProblemSolver] classes. */
val selectionProblemSolvers: List<Pair<String, SelectionProblemSolver>> =
listOfNotNull(
Pair("z3", Z3SelectionProblemSolver),
gurobiSolver?.let { Pair("gurobi", it) }
)

private object PackageName
Original file line number Diff line number Diff line change
@@ -1,189 +0,0 @@
package io.github.apl_cornell.viaduct.selection

import com.microsoft.z3.BoolExpr
import com.microsoft.z3.Context
import com.microsoft.z3.Global
import com.microsoft.z3.IntExpr
import com.microsoft.z3.IntNum
import com.microsoft.z3.Status
import com.microsoft.z3.enumerations.Z3_lbool
import com.uchuhimo.collections.BiMap
import com.uchuhimo.collections.mutableBiMapOf
import io.github.apl_cornell.viaduct.syntax.Protocol
import io.github.apl_cornell.viaduct.util.FreshNameGenerator
import mu.KotlinLogging

private val logger = KotlinLogging.logger("Z3Selection")

/**
* Constraint problem using Z3. Z3 has an optimization module that can return models with minimal cost.
*/
class Z3Selection : SelectionProblemSolver {
companion object {
init {
// Use old arithmetic solver to fix regression introduced in Z3 v4.8.9
Global.setParameter("smt.arith.solver", "2")
}
}

private val nameGenerator = FreshNameGenerator()

/** Convert a SelectionConstraint into a Z3 BoolExpr. **/
private fun boolExpr(
constraint: SelectionConstraint,
ctx: Context,
vmap: BiMap<FunctionVariable, IntExpr>,
boolVarMap: Map<String, BoolExpr>,
protocolMap: BiMap<Protocol, Int>
): BoolExpr {
return when (constraint) {
is True -> ctx.mkTrue()
is False -> ctx.mkFalse()
is HostVariable -> boolVarMap[constraint.variable]!!
is GuardVisibilityFlag -> boolVarMap[constraint.variable]!!
is Literal -> ctx.mkBool(constraint.literalValue)
is Implies ->
ctx.mkImplies(
boolExpr(constraint.lhs, ctx, vmap, boolVarMap, protocolMap),
boolExpr(constraint.rhs, ctx, vmap, boolVarMap, protocolMap)
)

is Or ->
constraint.props.fold(ctx.mkFalse()) { acc, prop ->
ctx.mkOr(acc, boolExpr(prop, ctx, vmap, boolVarMap, protocolMap))
}

is And ->
constraint.props.fold(ctx.mkTrue()) { acc, prop ->
ctx.mkAnd(acc, boolExpr(prop, ctx, vmap, boolVarMap, protocolMap))
}

is Not -> ctx.mkNot(boolExpr(constraint.rhs, ctx, vmap, boolVarMap, protocolMap))
is VariableIn -> ctx.mkEq(vmap[constraint.variable], ctx.mkInt(protocolMap[constraint.protocol]!!))
is VariableEquals -> ctx.mkEq(vmap[constraint.var1], vmap[constraint.var2])
}
}

/** Convert a CostExpression into a Z3 ArithExpr. */
private fun arithExpr(
symCost: SymbolicCost,
ctx: Context,
fvMap: BiMap<FunctionVariable, IntExpr>,
boolVarMap: Map<String, BoolExpr>,
protocolMap: BiMap<Protocol, Int>
): Pair<IntExpr, BoolExpr> =
when (symCost) {
is CostLiteral -> Pair(ctx.mkInt(symCost.cost), ctx.mkTrue())

is CostAdd -> {
val (lhsExpr, constrsL) = arithExpr(symCost.lhs, ctx, fvMap, boolVarMap, protocolMap)
val (rhsExpr, constrsR) = arithExpr(symCost.rhs, ctx, fvMap, boolVarMap, protocolMap)
Pair(ctx.mkAdd(lhsExpr, rhsExpr) as IntExpr, ctx.mkAnd(constrsL, constrsR))
}

is CostMul -> {
val (rhsExpr, constrs) = arithExpr(symCost.rhs, ctx, fvMap, boolVarMap, protocolMap)
Pair(ctx.mkMul(ctx.mkInt(symCost.lhs), rhsExpr) as IntExpr, constrs)
}

is CostMax -> {
val (lhsExpr, constrsL) = arithExpr(symCost.lhs, ctx, fvMap, boolVarMap, protocolMap)
val (rhsExpr, constrsR) = arithExpr(symCost.rhs, ctx, fvMap, boolVarMap, protocolMap)
Pair(
ctx.mkITE(ctx.mkGe(lhsExpr, rhsExpr), lhsExpr, rhsExpr) as IntExpr,
ctx.mkAnd(constrsL, constrsR)
)
}

is CostChoice -> {
val costVarName = this.nameGenerator.getFreshName("cost")
val costVar = ctx.mkFreshConst(costVarName, ctx.intSort) as IntExpr
Pair(
costVar,
ctx.mkAnd(
*symCost.choices.map { choice ->
val guardExpr = boolExpr(choice.first, ctx, fvMap, boolVarMap, protocolMap)
val (costExpr, costConstrs) = arithExpr(choice.second, ctx, fvMap, boolVarMap, protocolMap)
ctx.mkImplies(guardExpr, ctx.mkAnd(ctx.mkEq(costVar, costExpr), costConstrs))
}.toTypedArray()
)
)
}
}

/** Protocol selection. */
override fun solveSelectionProblem(problem: SelectionProblem): ProtocolAssignment? {
Context().use { ctx ->
val constraints = problem.constraints
val programCost = problem.cost

val solver = ctx.mkOptimize()

val protocolMap = mutableBiMapOf<Protocol, Int>()
val fvMap = mutableBiMapOf<FunctionVariable, IntExpr>()
val boolVarMap = mutableMapOf<String, BoolExpr>()

var protocolCounter = 1
for (constraint in constraints) {
for (fv in constraint.functionVariables()) {
if (!fvMap.containsKey(fv)) {
val fvSymname = this.nameGenerator.getFreshName("${fv.function.name}_${fv.variable.name}")
fvMap[fv] = ctx.mkFreshConst(fvSymname, ctx.intSort) as IntExpr
}
}

for (protocol in constraint.protocols()) {
if (!protocolMap.containsKey(protocol)) {
protocolMap[protocol] = protocolCounter
protocolCounter++
}
}

for (variable in constraint.variableNames()) {
if (!boolVarMap.containsKey(variable)) {
val varName = this.nameGenerator.getFreshName(variable)
boolVarMap[variable] = ctx.mkFreshConst(varName, ctx.boolSort) as BoolExpr
}
}
}

if (fvMap.values.isNotEmpty()) {
// load selection constraints into Z3
for (constraint in constraints) {
solver.Add(boolExpr(constraint, ctx, fvMap, boolVarMap, protocolMap))
}

val (costExpr, costConstrs) = arithExpr(programCost, ctx, fvMap, boolVarMap, protocolMap)
solver.Add(costConstrs)
solver.MkMinimize(costExpr)

val symvarCount = fvMap.size + boolVarMap.size

logger.info { "number of symvars: $symvarCount" }

if (solver.Check() == Status.SATISFIABLE) {
val model = solver.model
val assignment =
ProtocolAssignment(
fvMap.mapValues { e ->
val protocolIndex = (model.getConstInterp(e.value) as IntNum).int
protocolMap.inverse.getValue(protocolIndex)
},
boolVarMap.mapValues { kv ->
(model.evaluate(kv.value, false) as BoolExpr).boolValue == Z3_lbool.Z3_L_TRUE
},
problem
)

logger.info { "constraints satisfiable, extracted model" }

return assignment
} else {
return null
}
} else {
return ProtocolAssignment(mapOf(), mapOf(), problem)
}
}
}
}
Loading