Skip to content

Commit

Permalink
Add way to determine sign (+/-)
Browse files Browse the repository at this point in the history
  • Loading branch information
sakehl committed Dec 4, 2024
1 parent dc107f0 commit cf5b6de
Show file tree
Hide file tree
Showing 4 changed files with 391 additions and 120 deletions.
177 changes: 118 additions & 59 deletions src/col/vct/col/ast/util/ExpressionEqualityCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ object ExpressionEqualityCheck {
def equalExpressions[G](lhs: Expr[G], rhs: Expr[G]): Boolean = {
ExpressionEqualityCheck().equalExpressions(lhs, rhs)
}

trait Sign
case class Pos() extends Sign
case class Neg() extends Sign
}

case class InconsistentVariableEquality(v: Local[_], x: BigInt, y: BigInt)
Expand All @@ -31,6 +35,8 @@ case class InconsistentVariableEquality(v: Local[_], x: BigInt, y: BigInt)
}

class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) {
import ExpressionEqualityCheck._

var replacerDepth = 0
var replacerDepthInt = 0
val max_depth = 100
Expand Down Expand Up @@ -255,9 +261,9 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) {
return Some(b1 * b2)
}
}
// THe other cases are to complicated, so we do not consider them
case Mod(e1, e2) if isLower => return Some(0)
case Mod(e1, e2) => isConstantInt(e2)
// The other cases are to complicated, so we do not consider them
case _ =>
}

Expand Down Expand Up @@ -291,12 +297,62 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) {
None
}

def isNonZero(e: Expr[G]): Boolean =
def isNonZero(e: Expr[G]): Option[Boolean] = {
e match {
case v: Local[G] => info.exists(_.variableNotZero.contains(v))
case _ => isConstantInt(e).getOrElse(0) != 0
case _ => lessThenEq(const(1)(e.o), e).getOrElse(false)
case v: Local[G] if info.exists(_.variableNotZero.contains(v)) => return Some(true)
case _ =>
}
isConstantInt(e).map(i => i != 0) orElse
upperBound(e).flatMap(i => if(i<0) Some(true) else None) orElse
lowerBound(e).flatMap(i => if(i>0) Some(true) else None) orElse
lessThenEq(const(1)(e.o), e) orElse
lessThenEq(e, const(-1)(e.o))
}

def isSameSign(e1: Expr[G], e2: Expr[G]): Option[Boolean] = {
// Try to gets signs
(getSign(e1), getSign(e2)) match {
case (Some(s1), Some(s2)) => return Some(s1 == s2)
case _ =>
}

// Determine equal parts of a multiplication
val (rest_e1, rest_e2) = removeEqExprs(
unfoldComm[Mult[G]](e1), unfoldComm[Mult[G]](e2)
)

if(rest_e1.isEmpty && rest_e2.isEmpty) return Some(true)

// Check polarity of rest terms. A negative pol changes the sign.
val polarity1 = rest_e1.map(getSign).foldLeft(true)({
case (_, None) => return None
case (p, Some(Pos())) => p
case (p, Some(Neg())) => !p
})

val polarity2 = rest_e2.map(getSign).foldLeft(true)({
case (_, None) => return None
case (p, Some(Pos())) => p
case (p, Some(Neg())) => !p
})

Some(polarity1 == polarity2)
}

def isPos(b: Boolean): Sign =
if(b) Pos() else Neg()

def isPos(s: Sign): Boolean = s match {
case Pos() => true
case Neg() => false
}

def getSign(e: Expr[G]): Option[Sign] = {
isConstantInt(e).map(i => isPos(i>=0)) orElse
lowerBound(e).flatMap(i => if(i>=0) Some(Pos()) else None) orElse
upperBound(e).flatMap(i => if(i<0) Some(Neg()) else None) orElse
lessThenEq(const(0)(e.o), e).map(isPos)
}

def unfoldComm[B <: BinExpr[G]](
e: Expr[G]
Expand All @@ -307,71 +363,74 @@ class ExpressionEqualityCheck[G](info: Option[AnnotationVariableInfo[G]]) {
}
}

//
def equalExpressionsRecurse(lhs: Expr[G], rhs: Expr[G]): Boolean = {
(isConstantInt(lhs), isConstantInt(rhs)) match {
case (Some(i1), Some(i2)) => return i1 == i2
case (None, None) => ()
// If one is a constant expression, and the other is not, this cannot be the same
case _ => return false
def partitionOptionList[A, B](
xs: Seq[A],
f: A => Option[B],
): (Seq[A], Seq[B]) = {
var resLeft: Seq[A] = Seq()
var resRight: Seq[B] = Seq()
for (x <- xs) {
f(x) match {
case Some(b) => resRight ++= Seq(b)
case None => resLeft ++= Seq(x)
}
}
(resLeft, resRight)
}

def partitionOptionList[A, B](
xs: Seq[A],
f: A => Option[B],
): (Seq[A], Seq[B]) = {
var resLeft: Seq[A] = Seq()
var resRight: Seq[B] = Seq()
for (x <- xs) {
f(x) match {
case Some(b) => resRight ++= Seq(b)
case None => resLeft ++= Seq(x)
def removeEqExprs(e1s: Seq[Expr[G]], e2s: Seq[Expr[G]]): (Seq[Expr[G]], Seq[Expr[G]]) = {
var resultingE2: Seq[Expr[G]] = e2s
var resultingE1: Seq[Expr[G]] = Seq()

for (x <- e1s) {
var found = false
val freezeAvailable = resultingE2
for (y <- freezeAvailable) {
if (!found && equalExpressionsRecurse(x, y)) {
found = true
resultingE2 = resultingE2.diff(Seq(y))
}
}
(resLeft, resRight)
if (!found) resultingE1 = resultingE1 :+ x
}
(resultingE1, resultingE2)
}

def commAssoc[B <: BinExpr[G]](e1: B, e2: B)(
implicit tag: ClassTag[B]
): Boolean = {
val e1s = unfoldComm[B](e1)
val e2s = unfoldComm[B](e2)

val (e1rest, e1Ints) = partitionOptionList(e1s, isConstantInt)
val (e2rest, e2Ints) = partitionOptionList(e2s, isConstantInt)

if (e1rest.size != e2rest.size)
return false

val res1: Boolean =
e1 match {
case _: Plus[G] => e1Ints.sum == e2Ints.sum
case _: Mult[G] => e1Ints.product == e2Ints.product
// Should not be reachable
case _ => ???
}
if (!res1)
return false
def commAssoc[B <: BinExpr[G]](e1: B, e2: B)(
implicit tag: ClassTag[B]
): Boolean = {
val e1s = unfoldComm[B](e1)
val e2s = unfoldComm[B](e2)

var available: Seq[Expr[G]] = e2rest
val (e1rest, e1Ints) = partitionOptionList(e1s, isConstantInt)
val (e2rest, e2Ints) = partitionOptionList(e2s, isConstantInt)

for (x <- e1rest) {
var found = false
val freezeAvailable = available
if (e1rest.size != e2rest.size)
return false

for (y <- freezeAvailable) {
if (!found && equalExpressionsRecurse(x, y)) {
found = true
available = available.diff(Seq(y))
}
}
if (!found)
return false
val res1: Boolean =
e1 match {
case _: Plus[G] => e1Ints.sum == e2Ints.sum
case _: Mult[G] => e1Ints.product == e2Ints.product
// Should not be reachable
case _ => ???
}
if (!res1)
return false

val (e1restrest, e2restrest) = removeEqExprs(e1rest, e2rest)
return e1restrest.isEmpty && e2restrest.isEmpty
}

true
def equalExpressionsRecurse(lhs: Expr[G], rhs: Expr[G]): Boolean = {
(isConstantInt(lhs), isConstantInt(rhs)) match {
case (Some(i1), Some(i2)) => return i1 == i2
case (None, None) => ()
// If one is a constant expression, and the other is not, this cannot be the same
case _ => return false
}


def comm(
lhs1: Expr[G],
lhs2: Expr[G],
Expand Down Expand Up @@ -629,8 +688,8 @@ class AnnotationVariableInfoGetter[G]() {
}
case Less(e1, e2) => lt(e1, e2, equal = false)
case LessEq(e1, e2) => lt(e1, e2, equal = true)
case Greater(e1, e2) => lt(e2, e1, equal = true)
case GreaterEq(e1, e2) => lt(e2, e1, equal = false)
case Greater(e1, e2) => lt(e2, e1, equal = false)
case GreaterEq(e1, e2) => lt(e2, e1, equal = true)
case SeqMember(e1, Range(from, to)) =>
lt(from, e1, equal = true)
lt(e1, to, equal = false)
Expand Down
8 changes: 8 additions & 0 deletions src/col/vct/col/util/AstBuildHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import vct.col.ref.{DirectRef, Ref}
import vct.col.rewrite.Rewritten
import vct.result.VerificationError.{Unreachable, UserError}

import scala.language.implicitConversions

object Conversions {

}

/** Collection of general AST building utilities. This is meant to organically
* grow, so add helpers as you see fit.
*/
Expand Down Expand Up @@ -49,6 +55,8 @@ object AstBuildHelpers {
def +(right: Expr[G])(implicit origin: Origin): Plus[G] = Plus(left, right)
def -(right: Expr[G])(implicit origin: Origin): Minus[G] =
Minus(left, right)
def unary_-(implicit origin: Origin): UMinus[G] =
UMinus(left)
def *(right: Expr[G])(implicit origin: Origin): Mult[G] = Mult(left, right)
def /(
right: Expr[G]
Expand Down
Loading

0 comments on commit cf5b6de

Please sign in to comment.