Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter SMT lemmas for predicate checks and get-model #4048

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions booster/library/Booster/SMT/Interface.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ declareVariables transState = do
-}
initSolver :: Log.LoggerMIO io => KoreDefinition -> SMTOptions -> io SMT.SMTContext
initSolver def smtOptions = Log.withContext Log.CtxSMT $ do
prelude <- translatePrelude def
(prelude, lemmas) <- translatePrelude def

Log.logMessage ("Starting new SMT solver" :: Text)
ctxt <- mkContext smtOptions prelude
ctxt <- mkContext smtOptions prelude lemmas

evalSMT ctxt (runPrelude CheckSMTPrelude)
Log.logMessage ("Successfully initialised SMT solver with " <> (Text.pack . show $ smtOptions))
Expand All @@ -115,10 +115,11 @@ noSolver = do
, solverClose
, mbTranscriptHandle = Nothing
, prelude = []
, lemmas = mempty
, options = defaultSMTOptions{retryLimit = Just 0}
}

-- | Stop the solver, initialise a new one, set the timeout and re-check the prelude
-- | Stop the solver, initialise a new one, set the timeout and re-run prelude (without check)
hardResetSolver :: Log.LoggerMIO io => SMT io ()
hardResetSolver = do
ctxt <- SMT get
Expand Down Expand Up @@ -147,14 +148,17 @@ retry cb onTimeout = do
cb
_ -> onTimeout

translatePrelude :: Log.LoggerMIO io => KoreDefinition -> io [DeclareCommand]
translatePrelude ::
Log.LoggerMIO io =>
KoreDefinition ->
io ([DeclareCommand], Map SymbolName (Set DeclareCommand))
translatePrelude def =
let prelude = smtDeclarations def
in case prelude of
Left err -> do
Log.logMessage $ "Error translating definition to SMT: " <> err
throwSMT $ "Unable to translate elements of the definition to SMT: " <> err
Right decls -> pure decls
Right (decls, lemmas) -> pure (decls, lemmas)

pattern CheckSMTPrelude, NoCheckSMTPrelude :: Flag "CheckSMTPrelude"
pattern CheckSMTPrelude = Flag True
Expand All @@ -168,8 +172,11 @@ runPrelude doCheck = do
Log.logMessage ("Checking definition prelude" :: Text)
-- send the commands from the definition's SMT prelude
mapM_ runCmd ctxt.prelude
-- optionally check the prelude for consistency
-- optionally check prelude and lemmas for consistency
when (coerce doCheck) $ do
-- add all lemmas for the consistency check
let allLemmas = Set.toList $ Set.unions $ Map.elems ctxt.lemmas
mapM_ runCmd allLemmas
check <- runCmd CheckSat
case check of
Sat -> pure ()
Expand Down Expand Up @@ -219,8 +226,10 @@ isSatReturnTransState ctxt ps subst
Log.withContext Log.CtxAbort $ Log.logMessage $ "SMT translation error: " <> errMsg
smtTranslateError errMsg
| Right (smtToCheck, transState) <- translated = Log.withContext Log.CtxSMT $ do
-- add relevant SMT lemmas to the SMT assertions
let lemmas = selectLemmas ctxt.lemmas ps
evalSMT ctxt $
hardResetSolver >> solve smtToCheck transState
hardResetSolver >> solve (lemmas <> smtToCheck) transState
where
translated :: Either Text ([DeclareCommand], TranslationState)
translated =
Expand Down Expand Up @@ -393,8 +402,10 @@ checkPredicates ctxt givenPs givenSubst psToCheck
Log.withContext Log.CtxAbort $ Log.logMessage $ "SMT translation error: " <> errMsg
smtTranslateError errMsg
| Right ((smtGiven, sexprsToCheck), transState) <- translated = Log.withContext Log.CtxSMT $ do
-- add relevant SMT lemmas to smtGiven
let lemmas = selectLemmas ctxt.lemmas (Set.toList $ givenPs <> psToCheck)
evalSMT ctxt $
hardResetSolver >> solve smtGiven sexprsToCheck transState
hardResetSolver >> solve (smtGiven <> lemmas) sexprsToCheck transState
where
solve ::
[DeclareCommand] ->
Expand Down
15 changes: 8 additions & 7 deletions booster/library/Booster/SMT/Runner.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import Control.Monad.Trans.State
import Data.ByteString.Builder qualified as BS
import Data.ByteString.Char8 qualified as BS
import Data.IORef
import Data.Map (Map)
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import Data.Text (Text, pack)
import SMTLIB.Backends qualified as Backend
import SMTLIB.Backends.Process qualified as Backend
Expand Down Expand Up @@ -82,22 +84,20 @@ data SMTContext = SMTContext
, solverClose :: IORef (IO ())
, mbTranscriptHandle :: Maybe Handle
, prelude :: [DeclareCommand]
, lemmas :: Map SymbolName (Set DeclareCommand)
}

type SymbolName = BS.ByteString -- replicated from Booster.Pattern.Base

----------------------------------------
{- TODO (later)
- error handling and retries
- retry counter in context
- (possibly) run `get-info` on Unknown responses and enhance Unknown constructor
- smtlib2: reason-unknown = memout | incomplete | SExpr
-}

mkContext ::
LoggerMIO io =>
SMTOptions ->
[DeclareCommand] ->
Map SymbolName (Set DeclareCommand) ->
io SMTContext
mkContext opts prelude = do
mkContext opts prelude lemmas = do
logMessage ("Starting SMT solver" :: Text)
(solver', handle) <- connectToSolver opts.args
solver <- liftIO $ newIORef solver'
Expand All @@ -118,6 +118,7 @@ mkContext opts prelude = do
, solverClose
, mbTranscriptHandle
, prelude
, lemmas
, options = opts
}

Expand Down
86 changes: 77 additions & 9 deletions booster/library/Booster/SMT/Translate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module Booster.SMT.Translate (
backTranslateFrom,
runTranslator,
smtSort,
selectLemmas,
) where

import Control.Monad
Expand All @@ -25,9 +26,11 @@ import Data.Bifunctor (first)
import Data.ByteString.Char8 qualified as BS
import Data.Char (isDigit)
import Data.Coerce (coerce)
import Data.Foldable (toList)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text (Text, pack)
import Prettyprinter (pretty)
Expand All @@ -39,7 +42,7 @@ import Booster.Definition.Base
import Booster.Pattern.Base
import Booster.Pattern.Bool
import Booster.Pattern.Pretty
import Booster.Pattern.Util (sortOfTerm)
import Booster.Pattern.Util (filterTermSymbols, isFunctionSymbol, sortOfTerm)
import Booster.Prettyprinter qualified as Pretty
import Booster.SMT.Base as SMT
import Booster.SMT.LowLevelCodec as SMT
Expand Down Expand Up @@ -261,15 +264,16 @@ equationToSMTLemma equation
List [Atom "forall", List varPairs, lemmaRaw]

-- collect and render all declarations from a definition
smtDeclarations :: KoreDefinition -> Either Text [DeclareCommand]
smtDeclarations ::
KoreDefinition -> Either Text ([DeclareCommand], Map SymbolName (Set DeclareCommand))
smtDeclarations def
| Left msg <- translatedLemmas =
Left $ "Lemma translation failed: " <> msg
| Right (_, finalState) <- translatedLemmas
, not (Map.null finalState.mappings) =
Left . pack $ "Unexpected final state " <> show (finalState.mappings, finalState.counter)
| Right (lemmas, _) <- translatedLemmas =
Right $ concat [sortDecls, funDecls, lemmas]
Right (sortDecls <> funDecls, lemmas)
where
-- declare all sorts except Int and Bool
sortDecls =
Expand All @@ -282,14 +286,68 @@ smtDeclarations def
funDecls =
mapMaybe declareFunc $ Map.elems def.symbols

-- declare all SMT lemmas as assertions
allRules :: Map k (Map k' [v]) -> [v]
allRules = concat . concatMap Map.elems . Map.elems
extractLemmas = fmap catMaybes . mapM equationToSMTLemma . allRules
-- declare all SMT lemmas as assertions and construct a lookup table Symbol -> Lemmas
allSMTEquations :: Theory (RewriteRule t) -> Translator [(RewriteRule t, DeclareCommand)]
allSMTEquations =
fmap catMaybes
. mapM (\e -> fmap (e,) <$> equationToSMTLemma e)
. filter (coerce . (.attributes.smtLemma))
. concat
. concatMap Map.elems
. Map.elems

-- collect function symbols of an equation (LHS + requires, RHS)
collectSymbols :: RewriteRule t -> ([SymbolName], [SymbolName])
collectSymbols rule =
( collectNames rule.lhs <> concatMap (collectNames . coerce) rule.requires
, collectNames rule.rhs
)

-- symbol used on LHS => lookup must include sym -> this rule
-- symbol used on RHS => lookups returning this rule must be
-- _extended_ by rules reachable from that symbol. Requires a
-- transitive closure of the lookup map.

initialLookup ::
Theory (RewriteRule t) ->
Translator (Map SymbolName (Set (RewriteRule t, DeclareCommand)))
initialLookup = fmap (Map.unionsWith (<>) . map mapFrom) . allSMTEquations
where
mapFrom (eqn, smt) =
Map.fromList [(sym, Set.singleton (eqn, smt)) | sym <- fst $ collectSymbols eqn]

closeOverSymbols ::
forall a t.
Ord a =>
Map SymbolName (Set (RewriteRule t, a)) ->
Map SymbolName (Set (RewriteRule t, a))
closeOverSymbols start = go start
where
keys = Map.keys start -- should not change
go ::
Map SymbolName (Set (RewriteRule t, a)) -> Map SymbolName (Set (RewriteRule t, a))
go current =
let new = execState (mapM updateMapFor keys) current
in if new == current then new else go new

updateMapFor ::
SymbolName -> State (Map SymbolName (Set (RewriteRule t, a))) ()
updateMapFor k = do
m <- get
case Map.lookup k m of
Nothing -> pure () -- should not happen, keys won't change
Just eqs -> do
let rhsSyms = concatMap (snd . collectSymbols . fst) $ toList eqs
newEqs = Set.unions $ mapMaybe (flip Map.lookup m) rhsSyms
newM = Map.update (Just . (<> newEqs)) k m
put newM

translatedLemmas :: Either Text (Map SymbolName (Set DeclareCommand), TranslationState)
translatedLemmas =
runTranslator $
(<>) <$> extractLemmas def.functionEquations <*> extractLemmas def.simplifications
let trans :: Theory (RewriteRule t) -> Translator (Map SymbolName (Set DeclareCommand))
trans = fmap (Map.map (Set.map snd) . closeOverSymbols) . initialLookup
in runTranslator $
(<>) <$> trans def.simplifications <*> trans def.functionEquations

-- kore-rpc also declares all constructors, with no-junk axioms. WHY?

Expand All @@ -304,6 +362,16 @@ smtDeclarations def
(smtSort sym.resultSort)
| otherwise = Nothing

-- | helper to select SMT lemmas from the context given a predicate to check
selectLemmas :: Map SymbolName (Set DeclareCommand) -> [Predicate] -> [DeclareCommand]
selectLemmas m ps =
Set.toList $ Set.unions $ mapMaybe (flip Map.lookup m) usedFcts
where
usedFcts = concatMap (collectNames . coerce) ps

collectNames :: Term -> [SymbolName]
collectNames = map (.name) . filterTermSymbols isFunctionSymbol

smtName, quoted :: BS.ByteString -> SMTId
smtName = SMTId
-- All Kore sort names (except Int and Bool) need to be quoted |...| here.
Expand Down
4 changes: 2 additions & 2 deletions booster/unit-tests/Test/Booster/SMT/LowLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ declTests =
-- otherwise they might just get queued.
runSatAfter :: [SMTCommand] -> IO SMT.Response
runSatAfter commands = runNoLoggingT $ do
ctxt <- mkContext defaultSMTOptions []
ctxt <- mkContext defaultSMTOptions [] mempty
result <- evalSMT ctxt $ mapM_ runCmd commands >> runCmd CheckSat
closeContext ctxt
pure result
Expand Down Expand Up @@ -141,7 +141,7 @@ checkTests =
where
exec x =
runNoLoggingT $
mkContext defaultSMTOptions [] >>= \c -> evalSMT c x <* closeContext c
mkContext defaultSMTOptions [] mempty >>= \c -> evalSMT c x <* closeContext c
test name result decls =
testCase name $ (result @=?) =<< exec (runCheck decls)
returns = ($)
Loading