Skip to content
This repository has been archived by the owner on Aug 19, 2024. It is now read-only.

Commit

Permalink
Add k-induction for SmtModelCheckers (#713)
Browse files Browse the repository at this point in the history
* K-induction: Initial codepath.

* Add k-induction for SmtModelCheckers

Reused much of the BMC code.
- Perform BMC for cycles [0..k-1]
- Asserted constraints for cycles [n..n + k]
- Asserted assumptions for cycles [n..n + k-1]
- Checked the negation of the assertions for cycle n + k

* K-Induction: Initialize transition system at arbitrary step

In doing so there is no need to filter out the _resetActive constraint.

* K-Induction: Add zipcpu k-induction tests
  • Loading branch information
Gallagator authored Mar 8, 2024
1 parent 06e6ea8 commit bdb84f7
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 80 deletions.
11 changes: 11 additions & 0 deletions src/main/scala/chiseltest/formal/Formal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import firrtl2.transforms.formal.DontAssertSubmoduleAssumptionsAnnotation

sealed trait FormalOp extends NoTargetAnnotation
case class BoundedCheck(kMax: Int = -1) extends FormalOp
case class InductionCheck(kMax: Int = -1) extends FormalOp

/** Specifies how many cycles the circuit should be reset for. */
case class ResetOption(cycles: Int = 1) extends NoTargetAnnotation {
Expand All @@ -27,6 +28,14 @@ private[chiseltest] object FailedBoundedCheckException {
}
}

class FailedInductionCheckException(val message: String, val failAt: Int) extends Exception(message)
private[chiseltest] object FailedInductionCheckException {
def apply(module: String, failAt: Int): FailedInductionCheckException = {
val msg = s"[$module] found an assertion violation after $failAt steps!"
new FailedInductionCheckException(msg, failAt)
}
}

/** Adds the `verify` command for formal checks to a ChiselScalatestTester */
trait Formal { this: HasTestName =>
def verify[T <: Module](dutGen: => T, annos: AnnotationSeq, chiselAnnos: firrtl.AnnotationSeq = Seq()): Unit = {
Expand Down Expand Up @@ -79,5 +88,7 @@ private object Formal {
def executeOp(state: CircuitState, resetLength: Int, op: FormalOp): Unit = op match {
case BoundedCheck(kMax) =>
backends.Maltese.bmc(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength)
case InductionCheck(kMax) =>
backends.Maltese.induction(state.circuit, state.annotations, kMax = kMax, resetLength = resetLength)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ private[chiseltest] case class ModelCheckSuccess() extends ModelCheckResult { ov
private[chiseltest] case class ModelCheckFail(witness: Witness) extends ModelCheckResult {
override def isFail: Boolean = true
}
private[chiseltest] case class ModelCheckFailInduction(witness: Witness) extends ModelCheckResult {
override def isFail: Boolean = true
}

private[chiseltest] trait IsModelChecker {
def name: String
val prefix: String
val fileExtension: String
def check(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult
def checkBounded(sys: TransitionSystem, kMax: Int = -1): ModelCheckResult
def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult
}

private[chiseltest] case class Witness(
Expand Down
65 changes: 54 additions & 11 deletions src/main/scala/chiseltest/formal/backends/Maltese.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,20 @@ package chiseltest.formal.backends

import chiseltest.formal.backends.btor.BtormcModelChecker
import chiseltest.formal.backends.smt._
import chiseltest.formal.{DoNotModelUndef, DoNotOptimizeFormal, FailedBoundedCheckException}
import chiseltest.formal.{
DoNotModelUndef,
DoNotOptimizeFormal,
FailedBoundedCheckException,
FailedInductionCheckException
}
import firrtl2._
import firrtl2.annotations._
import firrtl2.stage._
import firrtl2.backends.experimental.smt.random._
import firrtl2.backends.experimental.smt._
import chiseltest.simulator._
import firrtl2.options.Dependency
import os.Path

sealed trait FormalEngineAnnotation extends NoTargetAnnotation

Expand Down Expand Up @@ -62,6 +68,26 @@ private[chiseltest] object Maltese {
require(kMax > 0)
require(resetLength >= 0)

val checkFn = (checker: IsModelChecker, sys: TransitionSystem) =>
checker.checkBounded(sys, kMax = kMax + resetLength);
check(circuit, annos, checkFn, resetLength);
}

def induction(circuit: ir.Circuit, annos: AnnotationSeq, kMax: Int, resetLength: Int = 0): Unit = {
require(kMax > 0)
require(resetLength >= 0)

val checkFn = (checker: IsModelChecker, sys: TransitionSystem) =>
checker.checkInduction(sys, resetLength, kMax = kMax);
check(circuit, annos, checkFn, resetLength);
}

def check(
circuit: ir.Circuit,
annos: AnnotationSeq,
checkFn: (IsModelChecker, TransitionSystem) => ModelCheckResult,
resetLength: Int
): Unit = {
// convert to transition system
val targetDir = Compiler.requireTargetDir(annos)
val modelUndef = !annos.contains(DoNotModelUndef)
Expand All @@ -77,19 +103,15 @@ private[chiseltest] object Maltese {
// perform check
val checkers = makeCheckers(annos, targetDir)
assert(checkers.size == 1, "Parallel checking not supported atm!")
checkers.head.check(sysInfo.sys, kMax = kMax + resetLength) match {
checkFn(checkers.head, sysInfo.sys) match {
case ModelCheckFail(witness) =>
val writeVcd = annos.contains(WriteVcdAnnotation)
if (writeVcd) {
val sim = new TransitionSystemSimulator(sysInfo.sys)
sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.bmc.vcd").toString))
val trace = witnessToTrace(sysInfo, witness)
val treadleState = prepTreadle(circuit, annos, modelUndef)
val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState)
Trace.replayOnSim(trace, treadleDut)
}
processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "bmc")
val failSteps = witness.inputs.length - 1 - resetLength
throw FailedBoundedCheckException(circuit.main, failSteps)
case ModelCheckFailInduction(witness) =>
processWitness(circuit, sysInfo, annos, witness, modelUndef, targetDir, "induction")
val failSteps = witness.inputs.length - 1
throw FailedInductionCheckException(circuit.main, failSteps)
case ModelCheckSuccess() => // good!
}
}
Expand All @@ -110,6 +132,27 @@ private[chiseltest] object Maltese {
}
}

// Produces a vcd file based on the witness is @annos contains WriteVcdAnnotation
private def processWitness(
circuit: ir.Circuit,
sysInfo: SysInfo,
annos: AnnotationSeq,
witness: Witness,
modelUndef: Boolean,
targetDir: Path,
vcdSuffix: String
) = {
val writeVcd = annos.contains(WriteVcdAnnotation)
if (writeVcd) {
val sim = new TransitionSystemSimulator(sysInfo.sys)
sim.run(witness, vcdFileName = Some((targetDir / s"${circuit.main}.${vcdSuffix}.vcd").toString))
val trace = witnessToTrace(sysInfo, witness)
val treadleState = prepTreadle(circuit, annos, modelUndef)
val treadleDut = TreadleBackendAnnotation.getSimulator.createContext(treadleState)
Trace.replayOnSim(trace, treadleDut)
}
}

private val LoweringAnnos: AnnotationSeq = Seq(
// we need to flatten the whole circuit
RunFirrtlTransformAnnotation(Dependency(FlattenPass)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ class BtormcModelChecker(targetDir: os.Path) extends IsModelChecker {
override val name: String = "btormc"
override val prefix: String = "btormc"

override def check(sys: TransitionSystem, kMax: Int): ModelCheckResult = {
override def checkInduction(sys: TransitionSystem, resetLenght: Int, kMax: Int = -1): ModelCheckResult = {
throw new RuntimeException(s"Induction unsupported for btormc");
}

override def checkBounded(sys: TransitionSystem, kMax: Int): ModelCheckResult = {
// serialize the system to btor2
val filename = sys.name + ".btor"
// btromc isn't happy if we include output nodes, so we skip them during serialization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ class CompactSmtEncoding(sys: TransitionSystem) extends TransitionSystemSmtEncod
s
}

def init(ctx: SolverContext): Unit = {
def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit = {
assert(states.isEmpty)
val s0 = appendState(ctx)
ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1))
if (!isArbitraryStep) {
ctx.assert(BVFunctionCall(stateInitFun, List(s0), 1))
}
}

def unroll(ctx: SolverContext): Unit = {
Expand Down
169 changes: 111 additions & 58 deletions src/main/scala/chiseltest/formal/backends/smt/SMTModelChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,80 @@ class SMTModelChecker(
override val prefix: String = solver.name
override val fileExtension: String = ".smt2"

override def check(
override def checkInduction(sys: TransitionSystem, resetLength: Int, kMax: Int = -1): ModelCheckResult = {
require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax")
// Check BMC first
checkBounded(sys, kMax + resetLength) match {
case ModelCheckFail(w) => return ModelCheckFail(w)
case _ =>
}

val (ctx, enc) = checkInit(sys)
// Initialise transition system at an arbitrary step
enc.init(ctx, true)

val constraints = sys.signals.filter(s => s.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)

(0 to kMax).foreach { k =>
// Assume all constraints hold for each k
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))
// Assume All assertions up to k
assertions.foreach(c => ctx.assert(enc.getAssertion(c)))
// Advance
enc.unroll(ctx)
}
// Assume constraints one last time
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))

val modelResult =
checkAssertions(sys, ctx, enc, assertions, kMax).map(ModelCheckFailInduction(_)).getOrElse(ModelCheckSuccess())
checkFini(ctx)
modelResult
}

override def checkBounded(
sys: TransitionSystem,
kMax: Int
): ModelCheckResult = {
require(kMax > 0 && kMax <= 2000, s"unreasonable kMax=$kMax")

val (ctx, enc) = checkInit(sys)
// Initialise transition system at reset
enc.init(ctx, false)

val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)

(0 to kMax).foreach { k =>
if (printProgress) println(s"Step #$k")

// assume all constraints hold in this step
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))

// make sure the constraints are not contradictory
if (options.checkConstraints) {
val res = ctx.check(produceModel = false)
assert(res.isSat, s"Found unsatisfiable constraints in cycle $k")
}

checkAssertions(sys, ctx, enc, assertions, k) match {
case Some(w) => {
checkFini(ctx)
return ModelCheckFail(w)
}
case _ => {}
}
// advance
enc.unroll(ctx)
}

checkFini(ctx)
ModelCheckSuccess()
}

// Initialise solver context and transition system
private def checkInit(sys: TransitionSystem): (SolverContext, TransitionSystemSmtEncoding) = {
val ctx = solver.createContext()
// z3 only supports the non-standard as-const array syntax when the logic is set to ALL
val logic = if (solver.name.contains("z3")) { "ALL" }
Expand All @@ -47,73 +115,58 @@ class SMTModelChecker(
new UnrollSmtEncoding(sys)
}
enc.defineHeader(ctx)
enc.init(ctx)

val constraints = sys.signals.filter(_.lbl == IsConstraint).map(_.name)
val assertions = sys.signals.filter(_.lbl == IsBad).map(_.name)
(ctx, enc)
}

(0 to kMax).foreach { k =>
if (printProgress) println(s"Step #$k")
private def checkFini(ctx: SolverContext) = {
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
}

// assume all constraints hold in this step
constraints.foreach(c => ctx.assert(enc.getConstraint(c)))
private def checkAssertions(
sys: TransitionSystem,
ctx: SolverContext,
enc: TransitionSystemSmtEncoding,
assertions: List[String],
k: Int
): Option[Witness] = {
if (options.checkBadStatesIndividually) {
// check each bad state individually
assertions.zipWithIndex.foreach { case (b, bi) =>
if (printProgress) print(s"- b$bi? ")

// make sure the constraints are not contradictory
if (options.checkConstraints) {
val res = ctx.check(produceModel = false)
assert(res.isSat, s"Found unsatisfiable constraints in cycle $k")
}

if (options.checkBadStatesIndividually) {
// check each bad state individually
assertions.zipWithIndex.foreach { case (b, bi) =>
if (printProgress) print(s"- b$bi? ")

ctx.push()
ctx.assert(BVNot(enc.getAssertion(b)))
val res = ctx.check(produceModel = false)

// did we find an assignment for which the bad state is true?
if (res.isSat) {
if (printProgress) println("")
val w = getWitness(ctx, sys, enc, k, Seq(b))
ctx.pop()
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
return ModelCheckFail(w)
} else {
if (printProgress) println("")
}
ctx.pop()
}
} else {
val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion)))
ctx.push()
ctx.assert(anyBad)
ctx.assert(BVNot(enc.getAssertion(b)))
val res = ctx.check(produceModel = false)

// did we find an assignment for which at least one bad state is true?
// did we find an assignment for which the bad state is true?
if (res.isSat) {
val w = getWitness(ctx, sys, enc, k)
ctx.pop()
if (printProgress) println("")
val w = getWitness(ctx, sys, enc, k, Seq(b))
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
return ModelCheckFail(w)
return Some(w)
} else {
if (printProgress) println("")
}
ctx.pop()
}

// advance
enc.unroll(ctx)
} else {
val anyBad = BVNot(BVAnd(assertions.map(enc.getAssertion)))
ctx.push()
ctx.assert(anyBad)
val res = ctx.check(produceModel = false)

// did we find an assignment for which at least one bad state is true?
if (res.isSat) {
val w = getWitness(ctx, sys, enc, k)
ctx.pop()
return Some(w)
}
ctx.pop()
}

// clean up
ctx.pop()
assert(ctx.stackDepth == 0, s"Expected solver stack to be empty, not: ${ctx.stackDepth}")
ctx.close()
ModelCheckSuccess()
None
}

private def getWitness(
Expand Down Expand Up @@ -151,10 +204,10 @@ class SMTModelChecker(

trait TransitionSystemSmtEncoding {
def defineHeader(ctx: SolverContext): Unit
def init(ctx: SolverContext): Unit
def init(ctx: SolverContext, isArbitraryStep: Boolean): Unit
def unroll(ctx: SolverContext): Unit
def getConstraint(name: String): BVExpr
def getAssertion(name: String): BVExpr
def getSignalAt(sym: BVSymbol, k: Int): BVExpr
def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr
def getSignalAt(sym: BVSymbol, k: Int): BVExpr
def getSignalAt(sym: ArraySymbol, k: Int): ArrayExpr
}
Loading

0 comments on commit bdb84f7

Please sign in to comment.