Skip to content

Commit

Permalink
More aggressive term simplification and some additional shortcuts
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoeilers committed Oct 27, 2024
1 parent e51ae4d commit acd55fb
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 27 deletions.
9 changes: 9 additions & 0 deletions src/main/scala/decider/TermToZ3APIConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ class TermToZ3APIConverter
case IntPermTimes(t0, t1) => ctx.mkMul(convertTerm(t0).asInstanceOf[ArithExpr], convertTerm(t1).asInstanceOf[ArithExpr])
case PermIntDiv(t0, t1) => ctx.mkDiv(convertToReal(t0), convertToReal(t1))
case PermPermDiv(t0, t1) => ctx.mkDiv(convertToReal(t0), convertToReal(t1))
case PermMax(t0, t1) => {
/*
(define-fun $Perm.min ((p1 $Perm) (p2 $Perm)) Real
(ite (<= p1 p2) p1 p2))
*/
val e0 = convert(t0).asInstanceOf[ArithExpr]
val e1 = convert(t1).asInstanceOf[ArithExpr]
ctx.mkITE(ctx.mkLe(e0, e1), e1, e0)
}
case PermMin(t0, t1) => {
/*
(define-fun $Perm.min ((p1 $Perm) (p2 $Perm)) Real
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/interfaces/state/Chunks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ package viper.silicon.interfaces.state
import viper.silicon.resources.ResourceID
import viper.silicon.state.terms.{Term, Var}

trait Chunk
trait Chunk {
def addEquality(t1: Term, t2: Term): Chunk = this
}

trait ChunkIdentifer

Expand Down
20 changes: 16 additions & 4 deletions src/main/scala/rules/Brancher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import viper.silicon.decider.PathConditionStack
import viper.silicon.interfaces.{Unreachable, VerificationResult}
import viper.silicon.reporting.condenseToViperResult
import viper.silicon.state.State
import viper.silicon.state.terms.{FunctionDecl, MacroDecl, Not, Term}
import viper.silicon.state.terms.{BuiltinEquals, FunctionDecl, Literal, MacroDecl, Not, Term, Var}
import viper.silicon.verifier.Verifier
import viper.silver.ast
import viper.silver.reporter.{BranchFailureMessage}
import viper.silver.reporter.BranchFailureMessage
import viper.silver.verifier.Failure

trait BranchingRules extends SymbolicExecutionRules {
Expand Down Expand Up @@ -146,7 +146,13 @@ object brancher extends BranchingRules {
if (v.uniqueId != v0.uniqueId)
v1.decider.prover.saturate(Verifier.config.proverSaturationTimeouts.afterContract)

val result = fElse(v1.stateConsolidator(s1).consolidateOptionally(s1, v1), v1)
val s1p = condition match {
case Not(BuiltinEquals(p0, p1)) if (p0.isInstanceOf[Var] || p0.isInstanceOf[Literal]) && (p1.isInstanceOf[Var] || p1.isInstanceOf[Literal]) =>
s1.addEquality(p0, p1)
case _ => s1
}

val result = fElse(v1.stateConsolidator(s1p).consolidateOptionally(s1p, v1), v1)
if (wasElseExecutedOnDifferentVerifier) {
v1.decider.resetProverOptions()
v1.decider.setProverOptions(proverArgsOfElseBranchDecider)
Expand Down Expand Up @@ -185,8 +191,14 @@ object brancher extends BranchingRules {
executionFlowController.locally(s, v)((s1, v1) => {
v1.decider.prover.comment(s"[then-branch: $cnt | $condition]")
v1.decider.setCurrentBranchCondition(condition, conditionExp)
val s1p = condition match {
case BuiltinEquals(p0, p1) if (p0.isInstanceOf[Var] || p0.isInstanceOf[Literal]) && (p1.isInstanceOf[Var] || p1.isInstanceOf[Literal]) =>
s1.addEquality(p0, p1)
case _ => s1
}
val s1pp = v1.stateConsolidator(s1p).consolidateOptionally(s1p, v1)

fThen(v1.stateConsolidator(s1).consolidateOptionally(s1, v1), v1)
fThen(s1pp, v1)
})
} else {
Unreachable()
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/rules/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import viper.silicon.state.terms._
import viper.silicon.state.terms.predef.`?r`
import viper.silicon.utils.freshSnap
import viper.silicon.verifier.Verifier
import viper.silver.ast.AnnotationInfo
import viper.silver.cfg.{ConditionalEdge, StatementBlock}

trait ExecutionRules extends SymbolicExecutionRules {
Expand Down Expand Up @@ -748,6 +749,7 @@ object executor extends ExecutionRules {
}

private def ssaifyRhs(rhs: Term, name: String, typ: ast.Type, v: Verifier): Term = {
return rhs
rhs match {
case _: Var | _: Literal =>
rhs
Expand Down
39 changes: 26 additions & 13 deletions src/main/scala/rules/chunks/MoreCompleteExhaleSupporter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package viper.silicon.rules.chunks

import viper.silicon.rules.chunks.chunkSupporter.findChunksWithID
import viper.silicon.interfaces.state._
import viper.silicon.interfaces.{Success, VerificationResult}
import viper.silicon.interfaces.{Success, Unreachable, VerificationResult}
import viper.silicon.resources.{FieldID, NonQuantifiedPropertyInterpreter, Resources}
import viper.silicon.rules.{Complete, ConsumptionResult, ConsumptionRules, Incomplete, SnapshotMapDefinition, SymbolicExecutionRules, magicWandSupporter}
import viper.silicon.state._
Expand Down Expand Up @@ -485,20 +485,26 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules {
val consumedChunks = ListBuffer[NonQuantifiedChunk]()
var moreNeeded = true

val definiteAlias = chunkSupporter.findChunk[NonQuantifiedChunk](relevantChunks, id, args, v).filter(c =>
v.decider.check(IsPositive(c.perm), Verifier.config.checkTimeout())
)
val (sortedChunks, checkedDefiniteAlias) = if (relevantChunks.size < 2) {
(relevantChunks, None)
} else {
val definiteAlias = chunkSupporter.findChunk[NonQuantifiedChunk](relevantChunks, id, args, v).filter(c =>
v.decider.check(IsPositive(c.perm), Verifier.config.checkTimeout())
)

val sortFunction: (NonQuantifiedChunk, NonQuantifiedChunk) => Boolean = (ch1, ch2) => {
// The definitive alias and syntactic aliases should get priority, since it is always
// possible to consume from them
definiteAlias.contains(ch1) || !definiteAlias.contains(ch2) && ch1.args == args
val sortFunction: (NonQuantifiedChunk, NonQuantifiedChunk) => Boolean = (ch1, ch2) => {
// The definitive alias and syntactic aliases should get priority, since it is always
// possible to consume from them
definiteAlias.contains(ch1) || !definiteAlias.contains(ch2) && ch1.args == args
}

(relevantChunks.sortWith(sortFunction), Some(definiteAlias))
}

val additionalArgs = s.relevantQuantifiedVariables
var currentFunctionRecorder = s.functionRecorder

relevantChunks.sortWith(sortFunction) foreach { ch =>
sortedChunks foreach { ch =>
if (moreNeeded) {
val eq = And(ch.args.zip(args).map { case (t1, t2) => t1 === t2 })

Expand Down Expand Up @@ -526,15 +532,20 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules {
pNeeded = PermMinus(pNeeded, pTaken)
val consumedChunk = ch.withPerm(pTaken)

if (!v.decider.check(pTaken=== NoPerm, Verifier.config.splitTimeout())) {
val noneTaken = pTaken=== NoPerm

if (noneTaken == False || !v.decider.check(noneTaken, Verifier.config.splitTimeout())) {
consumedChunks.append(consumedChunk)
}

if (!v.decider.check(IsNonPositive(newChunk.perm), Verifier.config.splitTimeout())) {
val newChunkHasNoPerm = IsNonPositive(newChunk.perm)

if (newChunkHasNoPerm == False || !v.decider.check(newChunkHasNoPerm, Verifier.config.splitTimeout())) {
newChunks.append(newChunk)
}

moreNeeded = !v.decider.check(pNeeded === NoPerm, Verifier.config.splitTimeout())
val noMoreNeeded = pNeeded === NoPerm
moreNeeded = noMoreNeeded == False || !v.decider.check(noMoreNeeded, Verifier.config.splitTimeout())
} else {
newChunks.append(ch)
}
Expand All @@ -553,7 +564,9 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules {

val s0 = s.copy(functionRecorder = currentFunctionRecorder)

summarise(s0, relevantChunks.toSeq, resource, args, Some(definiteAlias.map(_.snap)), v)((s1, snap, _, _, v1) => {
val checkedDefiniteValue = checkedDefiniteAlias.map(_.map(_.snap))

summarise(s0, relevantChunks.toSeq, resource, args, checkedDefiniteValue, v)((s1, snap, _, _, v1) => {
val condSnap = if (v1.decider.check(IsPositive(perms), Verifier.config.checkTimeout())) {
snap
} else {
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/state/Chunks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ case class BasicChunk(resourceID: BaseID,
case FieldID => s"${args.head}.$id -> $snap # $perm"
case PredicateID => s"$id($snap; ${args.mkString(",")}) # $perm"
}
override def addEquality(t1: Term, t2: Term) = {
BasicChunk(resourceID, id, args.map(_.replace(t1, t2)), snap.replace(t1, t2), perm.replace(t1, t2))
}
}

sealed trait QuantifiedBasicChunk extends QuantifiedChunk {
Expand Down
6 changes: 6 additions & 0 deletions src/main/scala/state/Heap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
package viper.silicon.state

import viper.silicon.interfaces.state.Chunk
import viper.silicon.state.terms.Term

trait Heap {
def values: Iterable[Chunk]
def +(chunk: Chunk): Heap
def +(other: Heap): Heap
def -(chunk: Chunk): Heap
def addEquality(t1: Term, t2: Term): Heap
}

trait HeapFactory[H <: Heap] {
Expand All @@ -38,4 +40,8 @@ final class ListBackedHeap private[state] (chunks: Vector[Chunk])

new ListBackedHeap(prefix ++ suffix.tail)
}
def addEquality(t1: Term, t2: Term) = {
val newChunks = chunks.map(_.addEquality(t1, t2))
new ListBackedHeap(newChunks)
}
}
9 changes: 9 additions & 0 deletions src/main/scala/state/State.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ final case class State(g: Store = Store(),

def cycles(m: ast.Member) = visited.count(_ == m)

def addEquality(t1: Term, t2: Term): State = {
if (t1 == t2) {
this
} else {
val newState = copy(g = g.addEquality(t1, t2), h = h.addEquality(t1, t2))
newState
}
}

def setConstrainable(arps: Iterable[Var], constrainable: Boolean) = {
val newConstrainableARPs =
if (constrainable) constrainableARPs ++ arps
Expand Down
5 changes: 5 additions & 0 deletions src/main/scala/state/Store.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ trait Store {
def get(key: ast.AbstractLocalVar): Option[Term]
def +(kv: (ast.AbstractLocalVar, Term)): Store
def +(other: Store): Store
def addEquality(t1: Term, t2: Term): Store
}

trait StoreFactory[ST <: Store] {
Expand All @@ -41,4 +42,8 @@ final class MapBackedStore private[state] (map: Map[ast.AbstractLocalVar, Term])
def get(key: ast.AbstractLocalVar) = map.get(key)
def +(entry: (ast.AbstractLocalVar, Term)) = new MapBackedStore(map + entry)
def +(other: Store) = new MapBackedStore(map ++ other.values)
def addEquality(t1: Term, t2: Term) = {
val newMap = map.map { case (k, v) => (k, v.replace(t1, t2)) }
new MapBackedStore(newMap)
}
}
81 changes: 72 additions & 9 deletions src/main/scala/state/Terms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1066,19 +1066,47 @@ object Ite extends CondFlyweightTermFactory[(Term, Term, Term), Ite] {
case (False, _, e2) => e2
case (e0, True, False) => e0
case (e0, False, True) => Not(e0)
case (c, e1, e2) =>
val eqs = getEqualities(c)

if (eqs.nonEmpty) {
val eqMap: scala.collection.immutable.Map[Term, Term] = eqs.map(eq => eq.p0 -> eq.p1).toMap
val eqFalseMap: scala.collection.immutable.Map[Term, Term] = eqs.flatMap(eq => {
Seq(eq -> False, eq.flip() -> False)
}).toMap
createIfNonExistent(c, replace(e1, eqMap), replace(e2, eqFalseMap))
} else {
createIfNonExistent(v0)
}
case _ => createIfNonExistent(v0)
}

def replace(t: Term, rps: scala.collection.immutable.Map[Term, Term]): Term = {
assert(rps.nonEmpty)
t.transform {
case trm if rps.contains(trm) => rps(trm)
}()
}

def getEqualities(t: Term): Seq[Equals] = t match {
case eq@Equals(_, _) => Seq(eq)
case And(ts) => ts.flatMap(getEqualities)
case _ => Seq()
}

override def actualCreate(args: (Term, Term, Term)): Ite = new Ite(args._1, args._2, args._3)
}

/* Comparison expression terms */

sealed trait ComparisonTerm extends BooleanTerm

sealed trait Equals extends ComparisonTerm with BinaryOp[Term] { override val op = "==" }
sealed trait Equals extends ComparisonTerm with BinaryOp[Term] {
override val op = "=="
def flip(): Equals
}

object Equals extends ((Term, Term) => BooleanTerm) {
object Equals extends ((Term, Term) => Term) {
def apply(e0: Term, e1: Term) = {
assert(e0.sort == e1.sort,
s"Expected both operands to be of the same sort, but found ${e0.sort} ($e0) and ${e1.sort} ($e1).")
Expand Down Expand Up @@ -1120,14 +1148,20 @@ object Equals extends ((Term, Term) => BooleanTerm) {
}

/* Represents built-in equality, e.g., '=' in SMT-LIB */
class BuiltinEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[BuiltinEquals] with Equals
class BuiltinEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[BuiltinEquals] with Equals {
override def flip() = BuiltinEquals.createIfNonExistent(p1, p0)
}

object BuiltinEquals extends CondFlyweightFactory[(Term, Term), BooleanTerm, BuiltinEquals] {
object BuiltinEquals extends CondFlyweightTermFactory[(Term, Term), BuiltinEquals] {
override def apply(v0: (Term, Term)) = v0 match {
case (p0, p1) if p0 == p1 => True // ME: trying stuff out.
case (p0, Ite(c, t1, t2)) => Ite(c, BuiltinEquals(p0, t1), BuiltinEquals(p0, t2))
case (Ite(c, t1, t2), p0) => Ite(c, BuiltinEquals(t1, p0), BuiltinEquals(t2, p0))
case (p0: PermLiteral, p1: PermLiteral) =>
// NOTE: The else-case (False) is only justified because permission literals are stored in a normal form
// such that two literals are semantically equivalent iff they are syntactically equivalent.
if (p0.literal == p1.literal) True else False
case (p0: Var, p1: Var) if p0 == p1 => True
case _ => createIfNonExistent(v0)
}

Expand All @@ -1136,7 +1170,7 @@ object BuiltinEquals extends CondFlyweightFactory[(Term, Term), BooleanTerm, Bui

/* Custom equality that (potentially) needs to be axiomatised. */
class CustomEquals private[terms] (val p0: Term, val p1: Term) extends ConditionalFlyweightBinaryOp[CustomEquals] with Equals {

override def flip() = CustomEquals.createIfNonExistent(p1, p0)
override val op = "==="
}

Expand Down Expand Up @@ -1445,6 +1479,11 @@ object PermPlus extends CondFlyweightTermFactory[(Term, Term), PermPlus] {
case (FractionPerm(n1, d1), FractionPerm(n2, d2)) if d1 == d2 => FractionPerm(Plus(n1, n2), d1)
case (PermMinus(t00, t01), t1) if t01 == t1 => t00
case (t0, PermMinus(t10, t11)) if t11 == t0 => t10
case (Ite(c, t1, t2), t3) => Ite(c, PermPlus(t1, t3), PermPlus(t2, t3))
case (t1, Ite(c, t2, t3)) => Ite(c, PermPlus(t1, t2), PermPlus(t1, t3))
case (PermMin(t0, t1), t2) => PermMin(PermPlus(t0, t2), PermPlus(t1, t2))
case (PermMax(t0, t1), t2) => PermMax(PermPlus(t0, t2), PermPlus(t1, t2))
case (t0, PermMax(t1, t2)) => PermMax(PermPlus(t0, t1), PermPlus(t0, t2))

case (_, _) => createIfNonExistent(v0)
}
Expand All @@ -1470,9 +1509,22 @@ object PermMinus extends CondFlyweightTermFactory[(Term, Term), PermMinus] {
case (t0, NoPerm) => t0
case (p0, p1) if p0 == p1 => NoPerm
case (p0: PermLiteral, p1: PermLiteral) => FractionPermLiteral(p0.literal - p1.literal)
case (p0, PermMinus(p1, p2)) if p0 == p1 => p2
case (p0, PermMinus(p1, p2)) =>
if (p0 == p1) {
p2
} else {
PermPlus(PermMinus(p0, p1), p2)
}
case (PermMinus(t0, t1), t2) => PermMinus(t0, PermPlus(t1, t2))
case (PermPlus(p0, p1), p2) if p0 == p2 => p1
case (PermPlus(p0, p1), p2) if p1 == p2 => p0
case (Ite(c, t1, t2), t3) => Ite(c, PermMinus(t1, t3), PermMinus(t2, t3))
case (t1, Ite(c, t2, t3)) => Ite(c, PermMinus(t1, t2), PermMinus(t1, t3))
case (PermMin(p0, p1), p2) => PermMin(PermMinus(p0, p2), PermMinus(p1, p2))
case (t0, PermMin(t1, t2)) => PermMax(PermMinus(t0, t1), PermMinus(t0, t2))
case (PermMax(p0, p1), p2) => PermMax(PermMinus(p0, p2), PermMinus(p1, p2))
case (t0, PermMax(t1, t2)) => PermMin(PermMinus(t0, t1), PermMinus(t0, t2))

case (_, _) => createIfNonExistent(v0)
}

Expand All @@ -1496,10 +1548,9 @@ object PermLess extends CondFlyweightTermFactory[(Term, Term), PermLess] {
case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal < p1.literal) True else False

case (t0, Ite(tCond, tIf, tElse)) =>
/* The pattern p0 < b ? p1 : p2 arises very often in the context of quantified permissions.
* Pushing the comparisons into the ite allows further simplifications.
*/
Ite(tCond, PermLess(t0, tIf), PermLess(t0, tElse))
case (Ite(tCond, tIf, tElse), t0) =>
Ite(tCond, PermLess(tIf, t0), PermLess(tElse, t0))

case _ => createIfNonExistent(v0)
}
Expand All @@ -1518,6 +1569,10 @@ object PermAtMost extends CondFlyweightTermFactory[(Term, Term), PermAtMost] {
override def apply(v0: (Term, Term)) = v0 match {
case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal <= p1.literal) True else False
case (t0, t1) if t0 == t1 => True
case (t0, Ite(tCond, tIf, tElse)) =>
Ite(tCond, PermAtMost(t0, tIf), PermAtMost(t0, tElse))
case (Ite(tCond, tIf, tElse), t0) =>
Ite(tCond, PermAtMost(tIf, t0), PermAtMost(tElse, t0))
case _ => createIfNonExistent(v0)
}

Expand All @@ -1538,6 +1593,10 @@ object PermMin extends CondFlyweightTermFactory[(Term, Term), PermMin] {
override def apply(v0: (Term, Term)) = v0 match {
case (t0, t1) if t0 == t1 => t0
case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal > p1.literal) p1 else p0
case (t0, Ite(tCond, tIf, tElse)) =>
Ite(tCond, PermMin(t0, tIf), PermMin(t0, tElse))
case (Ite(tCond, tIf, tElse), t0) =>
Ite(tCond, PermMin(tIf, t0), PermMin(tElse, t0))
case _ => createIfNonExistent(v0)
}

Expand All @@ -1558,6 +1617,10 @@ object PermMax extends CondFlyweightTermFactory[(Term, Term), PermMax] {
override def apply(v0: (Term, Term)) = v0 match {
case (t0, t1) if t0 == t1 => t0
case (p0: PermLiteral, p1: PermLiteral) => if (p0.literal < p1.literal) p1 else p0
case (t0, Ite(tCond, tIf, tElse)) =>
Ite(tCond, PermMax(t0, tIf), PermMax(t0, tElse))
case (Ite(tCond, tIf, tElse), t0) =>
Ite(tCond, PermMax(tIf, t0), PermMax(tElse, t0))
case _ => createIfNonExistent(v0)
}

Expand Down

0 comments on commit acd55fb

Please sign in to comment.