diff --git a/booster/library/Booster/SMT/Interface.hs b/booster/library/Booster/SMT/Interface.hs index e131435080..8a7a893e92 100644 --- a/booster/library/Booster/SMT/Interface.hs +++ b/booster/library/Booster/SMT/Interface.hs @@ -94,10 +94,10 @@ declareVariables transState = do -} initSolver :: Log.LoggerMIO io => KoreDefinition -> SMTOptions -> io SMT.SMTContext initSolver def smtOptions = Log.withContext Log.CtxSMT $ do - prelude <- translatePrelude def + (prelude, lemmas) <- translatePrelude def Log.logMessage ("Starting new SMT solver" :: Text) - ctxt <- mkContext smtOptions prelude + ctxt <- mkContext smtOptions prelude lemmas evalSMT ctxt (runPrelude CheckSMTPrelude) Log.logMessage ("Successfully initialised SMT solver with " <> (Text.pack . show $ smtOptions)) @@ -115,10 +115,11 @@ noSolver = do , solverClose , mbTranscriptHandle = Nothing , prelude = [] + , lemmas = mempty , options = defaultSMTOptions{retryLimit = Just 0} } --- | Stop the solver, initialise a new one, set the timeout and re-check the prelude +-- | Stop the solver, initialise a new one, set the timeout and re-run prelude (without check) hardResetSolver :: Log.LoggerMIO io => SMT io () hardResetSolver = do ctxt <- SMT get @@ -147,14 +148,17 @@ retry cb onTimeout = do cb _ -> onTimeout -translatePrelude :: Log.LoggerMIO io => KoreDefinition -> io [DeclareCommand] +translatePrelude :: + Log.LoggerMIO io => + KoreDefinition -> + io ([DeclareCommand], Map SymbolName (Set DeclareCommand)) translatePrelude def = let prelude = smtDeclarations def in case prelude of Left err -> do Log.logMessage $ "Error translating definition to SMT: " <> err throwSMT $ "Unable to translate elements of the definition to SMT: " <> err - Right decls -> pure decls + Right (decls, lemmas) -> pure (decls, lemmas) pattern CheckSMTPrelude, NoCheckSMTPrelude :: Flag "CheckSMTPrelude" pattern CheckSMTPrelude = Flag True @@ -168,8 +172,11 @@ runPrelude doCheck = do Log.logMessage ("Checking definition prelude" :: Text) -- send the commands from the definition's SMT prelude mapM_ runCmd ctxt.prelude - -- optionally check the prelude for consistency + -- optionally check prelude and lemmas for consistency when (coerce doCheck) $ do + -- add all lemmas for the consistency check + let allLemmas = Set.toList $ Set.unions $ Map.elems ctxt.lemmas + mapM_ runCmd allLemmas check <- runCmd CheckSat case check of Sat -> pure () @@ -219,8 +226,10 @@ isSatReturnTransState ctxt ps subst Log.withContext Log.CtxAbort $ Log.logMessage $ "SMT translation error: " <> errMsg smtTranslateError errMsg | Right (smtToCheck, transState) <- translated = Log.withContext Log.CtxSMT $ do + -- add relevant SMT lemmas to the SMT assertions + let lemmas = selectLemmas ctxt.lemmas ps evalSMT ctxt $ - hardResetSolver >> solve smtToCheck transState + hardResetSolver >> solve (lemmas <> smtToCheck) transState where translated :: Either Text ([DeclareCommand], TranslationState) translated = @@ -393,8 +402,10 @@ checkPredicates ctxt givenPs givenSubst psToCheck Log.withContext Log.CtxAbort $ Log.logMessage $ "SMT translation error: " <> errMsg smtTranslateError errMsg | Right ((smtGiven, sexprsToCheck), transState) <- translated = Log.withContext Log.CtxSMT $ do + -- add relevant SMT lemmas to smtGiven + let lemmas = selectLemmas ctxt.lemmas (Set.toList $ givenPs <> psToCheck) evalSMT ctxt $ - hardResetSolver >> solve smtGiven sexprsToCheck transState + hardResetSolver >> solve (smtGiven <> lemmas) sexprsToCheck transState where solve :: [DeclareCommand] -> diff --git a/booster/library/Booster/SMT/Runner.hs b/booster/library/Booster/SMT/Runner.hs index b1eedbd084..3e03966405 100644 --- a/booster/library/Booster/SMT/Runner.hs +++ b/booster/library/Booster/SMT/Runner.hs @@ -30,7 +30,9 @@ import Control.Monad.Trans.State import Data.ByteString.Builder qualified as BS import Data.ByteString.Char8 qualified as BS import Data.IORef +import Data.Map (Map) import Data.Maybe (fromMaybe) +import Data.Set (Set) import Data.Text (Text, pack) import SMTLIB.Backends qualified as Backend import SMTLIB.Backends.Process qualified as Backend @@ -82,22 +84,20 @@ data SMTContext = SMTContext , solverClose :: IORef (IO ()) , mbTranscriptHandle :: Maybe Handle , prelude :: [DeclareCommand] + , lemmas :: Map SymbolName (Set DeclareCommand) } +type SymbolName = BS.ByteString -- replicated from Booster.Pattern.Base + ---------------------------------------- -{- TODO (later) -- error handling and retries - - retry counter in context -- (possibly) run `get-info` on Unknown responses and enhance Unknown constructor - - smtlib2: reason-unknown = memout | incomplete | SExpr --} mkContext :: LoggerMIO io => SMTOptions -> [DeclareCommand] -> + Map SymbolName (Set DeclareCommand) -> io SMTContext -mkContext opts prelude = do +mkContext opts prelude lemmas = do logMessage ("Starting SMT solver" :: Text) (solver', handle) <- connectToSolver opts.args solver <- liftIO $ newIORef solver' @@ -118,6 +118,7 @@ mkContext opts prelude = do , solverClose , mbTranscriptHandle , prelude + , lemmas , options = opts } diff --git a/booster/library/Booster/SMT/Translate.hs b/booster/library/Booster/SMT/Translate.hs index dd721d8357..ce1da0a92f 100644 --- a/booster/library/Booster/SMT/Translate.hs +++ b/booster/library/Booster/SMT/Translate.hs @@ -15,6 +15,7 @@ module Booster.SMT.Translate ( backTranslateFrom, runTranslator, smtSort, + selectLemmas, ) where import Control.Monad @@ -25,9 +26,11 @@ import Data.Bifunctor (first) import Data.ByteString.Char8 qualified as BS import Data.Char (isDigit) import Data.Coerce (coerce) +import Data.Foldable (toList) import Data.Map (Map) import Data.Map qualified as Map import Data.Maybe (catMaybes, fromMaybe, mapMaybe) +import Data.Set (Set) import Data.Set qualified as Set import Data.Text (Text, pack) import Prettyprinter (pretty) @@ -39,7 +42,7 @@ import Booster.Definition.Base import Booster.Pattern.Base import Booster.Pattern.Bool import Booster.Pattern.Pretty -import Booster.Pattern.Util (sortOfTerm) +import Booster.Pattern.Util (filterTermSymbols, isFunctionSymbol, sortOfTerm) import Booster.Prettyprinter qualified as Pretty import Booster.SMT.Base as SMT import Booster.SMT.LowLevelCodec as SMT @@ -261,7 +264,8 @@ equationToSMTLemma equation List [Atom "forall", List varPairs, lemmaRaw] -- collect and render all declarations from a definition -smtDeclarations :: KoreDefinition -> Either Text [DeclareCommand] +smtDeclarations :: + KoreDefinition -> Either Text ([DeclareCommand], Map SymbolName (Set DeclareCommand)) smtDeclarations def | Left msg <- translatedLemmas = Left $ "Lemma translation failed: " <> msg @@ -269,7 +273,7 @@ smtDeclarations def , not (Map.null finalState.mappings) = Left . pack $ "Unexpected final state " <> show (finalState.mappings, finalState.counter) | Right (lemmas, _) <- translatedLemmas = - Right $ concat [sortDecls, funDecls, lemmas] + Right (sortDecls <> funDecls, lemmas) where -- declare all sorts except Int and Bool sortDecls = @@ -282,14 +286,68 @@ smtDeclarations def funDecls = mapMaybe declareFunc $ Map.elems def.symbols - -- declare all SMT lemmas as assertions - allRules :: Map k (Map k' [v]) -> [v] - allRules = concat . concatMap Map.elems . Map.elems - extractLemmas = fmap catMaybes . mapM equationToSMTLemma . allRules + -- declare all SMT lemmas as assertions and construct a lookup table Symbol -> Lemmas + allSMTEquations :: Theory (RewriteRule t) -> Translator [(RewriteRule t, DeclareCommand)] + allSMTEquations = + fmap catMaybes + . mapM (\e -> fmap (e,) <$> equationToSMTLemma e) + . filter (coerce . (.attributes.smtLemma)) + . concat + . concatMap Map.elems + . Map.elems + -- collect function symbols of an equation (LHS + requires, RHS) + collectSymbols :: RewriteRule t -> ([SymbolName], [SymbolName]) + collectSymbols rule = + ( collectNames rule.lhs <> concatMap (collectNames . coerce) rule.requires + , collectNames rule.rhs + ) + + -- symbol used on LHS => lookup must include sym -> this rule + -- symbol used on RHS => lookups returning this rule must be + -- _extended_ by rules reachable from that symbol. Requires a + -- transitive closure of the lookup map. + + initialLookup :: + Theory (RewriteRule t) -> + Translator (Map SymbolName (Set (RewriteRule t, DeclareCommand))) + initialLookup = fmap (Map.unionsWith (<>) . map mapFrom) . allSMTEquations + where + mapFrom (eqn, smt) = + Map.fromList [(sym, Set.singleton (eqn, smt)) | sym <- fst $ collectSymbols eqn] + + closeOverSymbols :: + forall a t. + Ord a => + Map SymbolName (Set (RewriteRule t, a)) -> + Map SymbolName (Set (RewriteRule t, a)) + closeOverSymbols start = go start + where + keys = Map.keys start -- should not change + go :: + Map SymbolName (Set (RewriteRule t, a)) -> Map SymbolName (Set (RewriteRule t, a)) + go current = + let new = execState (mapM updateMapFor keys) current + in if new == current then new else go new + + updateMapFor :: + SymbolName -> State (Map SymbolName (Set (RewriteRule t, a))) () + updateMapFor k = do + m <- get + case Map.lookup k m of + Nothing -> pure () -- should not happen, keys won't change + Just eqs -> do + let rhsSyms = concatMap (snd . collectSymbols . fst) $ toList eqs + newEqs = Set.unions $ mapMaybe (flip Map.lookup m) rhsSyms + newM = Map.update (Just . (<> newEqs)) k m + put newM + + translatedLemmas :: Either Text (Map SymbolName (Set DeclareCommand), TranslationState) translatedLemmas = - runTranslator $ - (<>) <$> extractLemmas def.functionEquations <*> extractLemmas def.simplifications + let trans :: Theory (RewriteRule t) -> Translator (Map SymbolName (Set DeclareCommand)) + trans = fmap (Map.map (Set.map snd) . closeOverSymbols) . initialLookup + in runTranslator $ + (<>) <$> trans def.simplifications <*> trans def.functionEquations -- kore-rpc also declares all constructors, with no-junk axioms. WHY? @@ -304,6 +362,16 @@ smtDeclarations def (smtSort sym.resultSort) | otherwise = Nothing +-- | helper to select SMT lemmas from the context given a predicate to check +selectLemmas :: Map SymbolName (Set DeclareCommand) -> [Predicate] -> [DeclareCommand] +selectLemmas m ps = + Set.toList $ Set.unions $ mapMaybe (flip Map.lookup m) usedFcts + where + usedFcts = concatMap (collectNames . coerce) ps + +collectNames :: Term -> [SymbolName] +collectNames = map (.name) . filterTermSymbols isFunctionSymbol + smtName, quoted :: BS.ByteString -> SMTId smtName = SMTId -- All Kore sort names (except Int and Bool) need to be quoted |...| here. diff --git a/booster/unit-tests/Test/Booster/SMT/LowLevel.hs b/booster/unit-tests/Test/Booster/SMT/LowLevel.hs index c079ff897c..a386e3ce9b 100644 --- a/booster/unit-tests/Test/Booster/SMT/LowLevel.hs +++ b/booster/unit-tests/Test/Booster/SMT/LowLevel.hs @@ -69,7 +69,7 @@ declTests = -- otherwise they might just get queued. runSatAfter :: [SMTCommand] -> IO SMT.Response runSatAfter commands = runNoLoggingT $ do - ctxt <- mkContext defaultSMTOptions [] + ctxt <- mkContext defaultSMTOptions [] mempty result <- evalSMT ctxt $ mapM_ runCmd commands >> runCmd CheckSat closeContext ctxt pure result @@ -141,7 +141,7 @@ checkTests = where exec x = runNoLoggingT $ - mkContext defaultSMTOptions [] >>= \c -> evalSMT c x <* closeContext c + mkContext defaultSMTOptions [] mempty >>= \c -> evalSMT c x <* closeContext c test name result decls = testCase name $ (result @=?) =<< exec (runCheck decls) returns = ($)