diff --git a/src/Language/Hasmtlib/Internal/Disjoin.hs b/src/Language/Hasmtlib/Internal/Disjoin.hs index 0e6ecf8..c5945fe 100644 --- a/src/Language/Hasmtlib/Internal/Disjoin.hs +++ b/src/Language/Hasmtlib/Internal/Disjoin.hs @@ -16,6 +16,7 @@ import Data.Sequence as Seq hiding ((|>), (<|)) import Data.Maybe import Data.Coerce import Data.STRef +import Data.List (nub) import qualified Data.Foldable as Foldable import Control.Lens import Control.Monad.ST @@ -48,8 +49,8 @@ disjoinST s = runST $ do modifySTRef' vId_fId_Ref (<> IntMap.fromList (fmap (, formulaId) vs)) _ -> do fId_fs <- readSTRef fId_fs_Ref - let mergedFs = f <| join (Seq.fromList $ mapMaybe (fId_fs IntMap.!?) fIds) - writeSTRef fId_fs_Ref $ Prelude.foldr IntMap.delete fId_fs fIds -- delete old formula associations + let mergedFs = f <| join (Seq.fromList $ nub $ mapMaybe (fId_fs IntMap.!?) fIds) + writeSTRef fId_fs_Ref $ Foldable.foldr' IntMap.delete fId_fs fIds -- delete old formula associations modifySTRef' fId_fs_Ref (at formulaId ?~ mergedFs) modifySTRef' vId_fId_Ref (IntMap.fromList (fmap (, formulaId) $ IntSet.toList $ varIdsAll mergedFs) <>) -- update new var associations diff --git a/src/Language/Hasmtlib/Internal/Expr.hs b/src/Language/Hasmtlib/Internal/Expr.hs index 7d9e6a6..07d2e06 100644 --- a/src/Language/Hasmtlib/Internal/Expr.hs +++ b/src/Language/Hasmtlib/Internal/Expr.hs @@ -4,6 +4,7 @@ module Language.Hasmtlib.Internal.Expr where +import Prelude hiding ((&&)) import Language.Hasmtlib.Internal.Render import Language.Hasmtlib.Type.ArrayMap import Language.Hasmtlib.Type.SMTSort @@ -12,6 +13,7 @@ import Data.Map hiding (toList) import Data.List (intercalate) import Data.Proxy import Data.Coerce +import Data.GADT.Compare import Data.Foldable (toList) import Data.ByteString.Builder import qualified Data.Vector.Sized as V @@ -28,8 +30,9 @@ data Value (t :: SMTSort) where IntValue :: HaskellType IntSort -> Value IntSort RealValue :: HaskellType RealSort -> Value RealSort BoolValue :: HaskellType BoolSort -> Value BoolSort - BvValue :: HaskellType (BvSort n) -> Value (BvSort n) - ArrayValue :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => HaskellType (ArraySort k v) -> Value (ArraySort k v) + BvValue :: KnownNat n => HaskellType (BvSort n) -> Value (BvSort n) + ArrayValue :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k), Eq (HaskellType v)) + => HaskellType (ArraySort k v) -> Value (ArraySort k v) -- | Unwrap a value from 'Value'. unwrapValue :: Value t -> HaskellType t @@ -57,7 +60,7 @@ type SomeKnownSMTSort f = SomeSMTSort '[KnownSMTSort] f -- For internal use only. -- For building expressions use the corresponding instances (Num, Boolean, ...). data Expr (t :: SMTSort) where - Var :: SMTVar t -> Expr t + Var :: KnownSMTSort t => SMTVar t -> Expr t Constant :: Value t -> Expr t Plus :: Num (HaskellType t) => Expr t -> Expr t -> Expr t @@ -119,13 +122,107 @@ data Expr (t :: SMTSort) where BvuGTHE :: KnownNat n => Expr (BvSort n) -> Expr (BvSort n) -> Expr BoolSort BvuGT :: KnownNat n => Expr (BvSort n) -> Expr (BvSort n) -> Expr BoolSort - ArrSelect :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => Expr (ArraySort k v) -> Expr k -> Expr v + ArrSelect :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k), Eq (HaskellType v)) => Expr (ArraySort k v) -> Expr k -> Expr v ArrStore :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => Expr (ArraySort k v) -> Expr k -> Expr v -> Expr (ArraySort k v) -- Just v if quantified var has been created already, Nothing otherwise ForAll :: KnownSMTSort t => Maybe (SMTVar t) -> (Expr t -> Expr BoolSort) -> Expr BoolSort Exists :: KnownSMTSort t => Maybe (SMTVar t) -> (Expr t -> Expr BoolSort) -> Expr BoolSort +instance GEq SMTVar where + geq x y = _ + +instance GEq Value where + geq (BoolValue x) (BoolValue y) = if x == y then Just Refl else Nothing + geq (IntValue x) (IntValue y) = if x == y then Just Refl else Nothing + geq (RealValue x) (RealValue y) = if x == y then Just Refl else Nothing + geq (BvValue x) (BvValue y) = case cmpNat x y of + EQI -> if x == y then Just Refl else Nothing + _ -> Nothing + geq ax@(ArrayValue x) ay@(ArrayValue y) = case geq (sortSing' ax) (sortSing' ay) of + Nothing -> Nothing + Just Refl -> if x == y then Just Refl else Nothing + geq _ _ = Nothing + +instance GEq Expr where + geq (Var x) (Var y) = geq x y + geq (Constant x) (Constant y) = geq x y + geq (Plus x y) (Plus x' y') = case geq x x' of Nothing -> Nothing ; Just Refl -> geq y y' + geq (Neg x) (Neg y) = geq x y + geq (Mul x y) (Mul x' y') = case geq x x' of Nothing -> Nothing ; Just Refl -> geq y y' + geq (Abs x) (Abs y) = geq x y + geq (Mod x y) (Mod x' y') = case geq x x' of Nothing -> Nothing ; Just Refl -> geq y y' + geq (IDiv x y) (IDiv x' y') = case geq x x' of Nothing -> Nothing ; Just Refl -> geq y y' + geq (LTH x y) (LTH x' y') = case geq x x' of + Nothing -> Nothing + Just Refl -> case geq y y' of + Nothing -> Nothing + Just Refl -> Just Refl + geq (LTHE x y) (LTHE x' y') = case geq x x' of + Nothing -> Nothing + Just Refl -> case geq y y' of + Nothing -> Nothing + Just Refl -> Just Refl + geq (EQU xs) (EQU ys) = _ + geq (Distinct xs) (Distinct ys) = _ + geq (GTHE x y) (GTHE x' y') = case geq x x' of + Nothing -> Nothing + Just Refl -> case geq y y' of + Nothing -> Nothing + Just Refl -> Just Refl + geq (GTH x y) (GTH x' y') = case geq x x' of + Nothing -> Nothing + Just Refl -> case geq y y' of + Nothing -> Nothing + Just Refl -> Just Refl + geq _ _ = Nothing + + +instance Eq (Value t) => Eq (Expr t) where + -- (Not x) = vars1 x + -- (And x y) = vars1 x <> vars1 y + -- (Or x y) = vars1 x <> vars1 y + -- (Impl x y) = vars1 x <> vars1 y + -- (Xor x y) = vars1 x <> vars1 y + -- Pi = mempty + -- (Sqrt x) = vars1 x + -- (Exp x) = vars1 x + -- (Sin x) = vars1 x + -- (Cos x) = vars1 x + -- (Tan x) = vars1 x + -- (Asin x) = vars1 x + -- (Acos x) = vars1 x + -- (Atan x) = vars1 x + -- (ToReal x) = vars1 x + -- (ToInt x) = vars1 x + -- (IsInt x) = vars1 x + -- (Ite p t f) = vars1 p <> vars1 t <> vars1 f + -- (BvNot x) = vars1 x + -- (BvAnd x y) = vars1 x <> vars1 y + -- (BvOr x y) = vars1 x <> vars1 y + -- (BvXor x y) = vars1 x <> vars1 y + -- (BvNand x y) = vars1 x <> vars1 y + -- (BvNor x y) = vars1 x <> vars1 y + -- (BvNeg x) = vars1 x + -- (BvAdd x y) = vars1 x <> vars1 y + -- (BvSub x y) = vars1 x <> vars1 y + -- (BvMul x y) = vars1 x <> vars1 y + -- (BvuDiv x y) = vars1 x <> vars1 y + -- (BvuRem x y) = vars1 x <> vars1 y + -- (BvShL x y) = vars1 x <> vars1 y + -- (BvLShR x y) = vars1 x <> vars1 y + -- (BvConcat x y) = vars1 x <> vars1 y + -- (BvRotL _ x) = vars1 x + -- (BvRotR _ x) = vars1 x + -- (BvuLT x y) = vars1 x <> vars1 y + -- (BvuLTHE x y) = vars1 x <> vars1 y + -- (BvuGTHE x y) = vars1 x <> vars1 y + -- (BvuGT x y) = vars1 x <> vars1 y + -- (ArrSelect i arr) = vars1 i <> vars1 arr + -- (ArrStore i x arr) = vars1 i <> vars1 x <> vars1 arr + -- (ForAll mQv expr) = vars1Q mQv expr + -- (Exists mQv expr) = vars1Q mQv expr + instance Boolean (Expr BoolSort) where bool = Constant . BoolValue {-# INLINE bool #-} @@ -174,7 +271,8 @@ instance Render (Value t) where | otherwise -> constRender v where constRender v = "((as const " <> render (goSing arr) <> ") " <> render (wrapValue v) <> ")" - goSing :: forall k v. (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => ConstArray (HaskellType k) (HaskellType v) -> SSMTSort (ArraySort k v) + goSing :: forall k v. (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k), Eq (HaskellType v)) + => ConstArray (HaskellType k) (HaskellType v) -> SSMTSort (ArraySort k v) goSing _ = sortSing @(ArraySort k v) instance KnownSMTSort t => Render (Expr t) where diff --git a/src/Language/Hasmtlib/Type/SMTSort.hs b/src/Language/Hasmtlib/Type/SMTSort.hs index c69ffb1..456acb2 100644 --- a/src/Language/Hasmtlib/Type/SMTSort.hs +++ b/src/Language/Hasmtlib/Type/SMTSort.hs @@ -36,7 +36,7 @@ data SSMTSort (t :: SMTSort) where SRealSort :: SSMTSort RealSort SBoolSort :: SSMTSort BoolSort SBvSort :: KnownNat n => Proxy n -> SSMTSort (BvSort n) - SArraySort :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => Proxy k -> Proxy v -> SSMTSort (ArraySort k v) + SArraySort :: (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k), Eq (HaskellType v)) => Proxy k -> Proxy v -> SSMTSort (ArraySort k v) deriving instance Show (SSMTSort t) deriving instance Eq (SSMTSort t) @@ -81,7 +81,7 @@ instance KnownSMTSort IntSort where sortSing = SIntSort instance KnownSMTSort RealSort where sortSing = SRealSort instance KnownSMTSort BoolSort where sortSing = SBoolSort instance KnownNat n => KnownSMTSort (BvSort n) where sortSing = SBvSort (Proxy @n) -instance (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k)) => KnownSMTSort (ArraySort k v) where +instance (KnownSMTSort k, KnownSMTSort v, Ord (HaskellType k), Eq (HaskellType v)) => KnownSMTSort (ArraySort k v) where sortSing = SArraySort (Proxy @k) (Proxy @v) -- | Wrapper for 'sortSing' which takes a 'Proxy'-like argument for @t@.