From 8571b558f374090af97d107e41376b19a5ce00e0 Mon Sep 17 00:00:00 2001 From: George Thomas Date: Thu, 26 Oct 2023 16:13:11 +0100 Subject: [PATCH] hack: Evaluate arbitrary lambdas in animation primitive Signed-off-by: George Thomas --- primer/src/Primer/Eval/Prim.hs | 2 +- primer/src/Primer/Eval/Redex.hs | 27 ++++++-- primer/src/Primer/Primitives.hs | 109 ++++++++++++++++---------------- 3 files changed, 78 insertions(+), 60 deletions(-) diff --git a/primer/src/Primer/Eval/Prim.hs b/primer/src/Primer/Eval/Prim.hs index a5c6f42b0..07d2e3b9c 100644 --- a/primer/src/Primer/Eval/Prim.hs +++ b/primer/src/Primer/Eval/Prim.hs @@ -34,7 +34,7 @@ data ApplyPrimFunDetail = ApplyPrimFunDetail -- | If this node is a reducible application of a primitive, return the name of the primitive, the arguments, and -- (a computation for building) the result. -tryPrimFun :: Map GVarName PrimDef -> Expr -> Maybe (GVarName, [Expr], forall m. MonadFresh ID m => m Expr) +tryPrimFun :: Map GVarName PrimDef -> Expr -> Maybe (GVarName, [Expr], forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr) tryPrimFun primDefs expr | -- Since no primitive functions are polymorphic, there is no need to unfoldAPP (Var _ (GlobalVarRef name), args) <- bimap stripAnns (map stripAnns) $ unfoldApp expr diff --git a/primer/src/Primer/Eval/Redex.hs b/primer/src/Primer/Eval/Redex.hs index 71bcf6738..8a1bed6db 100644 --- a/primer/src/Primer/Eval/Redex.hs +++ b/primer/src/Primer/Eval/Redex.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BlockArguments #-} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE OverloadedRecordDot #-} @@ -27,9 +28,9 @@ import Foreword import Control.Monad.Fresh (MonadFresh) import Control.Monad.Log (MonadLog, WithSeverity) -import Control.Monad.Trans.Maybe (MaybeT) +import Control.Monad.Trans.Maybe (MaybeT, runMaybeT) import Data.Data (Data) -import Data.Generics.Uniplate.Data (children, descendM) +import Data.Generics.Uniplate.Data (children, descendM, transformM) import Data.List (zip3) import Data.Map qualified as M import Data.Set qualified as S @@ -422,7 +423,7 @@ data Redex -- ^ The original redex (used for details) } | ApplyPrimFun - { result :: forall m. MonadFresh ID m => m Expr + { result :: forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr -- ^ The result of the applied primitive function , primFun :: GVarName -- ^ The applied primitive function (used for details) @@ -430,6 +431,8 @@ data Redex -- ^ The original arguments to @primFun@ (used for details) , orig :: Expr -- ^ The original redex (used for details) + , tydefs :: TypeDefMap + , globals :: DefMap } data RedexType @@ -756,7 +759,7 @@ viewRedex opts tydefs globals dir = \case $ hoistMaybe $ tryPrimFun (M.mapMaybe defPrim globals) e >>= \(primFun, args, result) -> - pure ApplyPrimFun{result, primFun, args, orig = e} + pure ApplyPrimFun{result, primFun, args, orig = e, tydefs, globals} -- (Λa.t : ∀b.T) S ~> (letType a = S in t) : (letType b = S in T) orig@(APP _ (Ann _ (LAM m a body) (TForall _ forallVar forallKind tgtTy)) argTy) -> pure @@ -1172,8 +1175,20 @@ runRedex opts = \case -- We should replace this with a proper exception. See: -- https://github.com/hackworthltd/primer/issues/148 | otherwise -> error "Internal Error: RenameBindingsCase found no applicable branches" - ApplyPrimFun{result, primFun, orig, args} -> do - expr' <- result + ApplyPrimFun{result, primFun, orig, args, tydefs, globals} -> do + -- TODO this can run forever - we haven't set a bound on number of steps + -- TODO `transformM` probably doesn't give us the right eval order - reuse existing machinery + expr' <- result $ fix $ \f -> transformM \e -> + maybe (pure e) (f . fst <=< runRedex opts) + =<< runMaybeT + ( flip runReaderT mempty + $ viewRedex + (ViewRedexOptions True True) -- TODO ? + tydefs + globals + Syn -- TODO ? + e + ) let details = ApplyPrimFunDetail { before = orig diff --git a/primer/src/Primer/Primitives.hs b/primer/src/Primer/Primitives.hs index 4afbe2bef..0db8eb5fc 100644 --- a/primer/src/Primer/Primitives.hs +++ b/primer/src/Primer/Primitives.hs @@ -40,6 +40,8 @@ import Data.Aeson (FromJSON (..), ToJSON (..)) import Data.ByteString.Base64 qualified as B64 import Data.Data (Data) import Data.Map qualified as M +import Data.Set (isSubsetOf) +import Data.Set qualified as Set import Diagrams.Backend.Rasterific ( Options (RasterificOptions), Rasterific (Rasterific), @@ -47,6 +49,7 @@ import Diagrams.Backend.Rasterific ( import Diagrams.Prelude ( Diagram, V2 (..), + blue, circle, deg, fillColor, @@ -57,6 +60,7 @@ import Diagrams.Prelude ( renderDia, rotate, sRGB24, + text, translate, (@@), ) @@ -78,9 +82,9 @@ import Primer.Core ( GVarName, GlobalName, ID, + LocalName (unLocalName), ModuleName, PrimCon (PrimAnimation, PrimChar, PrimInt), - TmVarRef (LocalVarRef), TyConName, Type' (..), ValConName, @@ -94,7 +98,7 @@ import Primer.Core.DSL ( prim, tcon, ) -import Primer.Core.Utils (generateIDs) +import Primer.Core.Utils (freeVars, generateIDs) import Primer.JSON (CustomJSON (..), PrimerJSON) import Primer.Name (Name) import Primer.Primitives.PrimDef (PrimDef (..)) @@ -236,48 +240,49 @@ primFunTypes = \case a = TApp () f = TFun () -primFunDef :: PrimDef -> [Expr' () () ()] -> Either PrimFunError (forall m. MonadFresh ID m => m Expr) +primFunDef :: PrimDef -> [Expr' () () ()] -> Either PrimFunError (forall m. MonadFresh ID m => (Expr -> m Expr) -> m Expr) primFunDef def args = case def of ToUpper -> case args of [PrimCon _ (PrimChar c)] -> - Right $ char $ toUpper c + Right $ const $ char $ toUpper c _ -> err IsSpace -> case args of [PrimCon _ (PrimChar c)] -> - Right $ boolAnn (isSpace c) + Right $ const $ boolAnn (isSpace c) _ -> err HexToNat -> case args of - [PrimCon _ (PrimChar c)] -> Right $ maybeAnn (tcon tNat) nat (digitToIntSafe c) + [PrimCon _ (PrimChar c)] -> Right $ const $ maybeAnn (tcon tNat) nat (digitToIntSafe c) where digitToIntSafe :: Char -> Maybe Natural digitToIntSafe c' = fromIntegral <$> (guard (isHexDigit c') $> digitToInt c') _ -> err NatToHex -> case args of [exprToNat -> Just n] -> - Right $ maybeAnn (tcon tChar) char $ intToDigitSafe n + Right $ const $ maybeAnn (tcon tChar) char $ intToDigitSafe n where intToDigitSafe :: Natural -> Maybe Char intToDigitSafe n' = guard (0 <= n && n <= 15) $> intToDigit (fromIntegral n') _ -> err EqChar -> case args of [PrimCon _ (PrimChar c1), PrimCon _ (PrimChar c2)] -> - Right $ boolAnn $ c1 == c2 + Right $ const $ boolAnn $ c1 == c2 _ -> err IntAdd -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ int $ x + y + Right $ const $ int $ x + y _ -> err IntMinus -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ int $ x - y + Right $ const $ int $ x - y _ -> err IntMul -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ int $ x * y + Right $ const $ int $ x * y _ -> err IntQuotient -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> Right + $ const $ maybeAnn (tcon tInt) int $ if y == 0 then Nothing @@ -286,6 +291,7 @@ primFunDef def args = case def of IntRemainder -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> Right + $ const $ maybeAnn (tcon tInt) int $ if y == 0 then Nothing @@ -294,12 +300,14 @@ primFunDef def args = case def of IntQuot -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> Right + $ const $ int $ if y == 0 then 0 else x `div` y _ -> err IntRem -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> Right + $ const $ int $ if y == 0 then x @@ -307,31 +315,32 @@ primFunDef def args = case def of _ -> err IntLT -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x < y + Right $ const $ boolAnn $ x < y _ -> err IntLTE -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x <= y + Right $ const $ boolAnn $ x <= y _ -> err IntGT -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x > y + Right $ const $ boolAnn $ x > y _ -> err IntGTE -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x >= y + Right $ const $ boolAnn $ x >= y _ -> err IntEq -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x == y + Right $ const $ boolAnn $ x == y _ -> err IntNeq -> case args of [PrimCon _ (PrimInt x), PrimCon _ (PrimInt y)] -> - Right $ boolAnn $ x /= y + Right $ const $ boolAnn $ x /= y _ -> err IntToNat -> case args of [PrimCon _ (PrimInt x)] -> Right + $ const $ maybeAnn (tcon tNat) nat $ if x < 0 then Nothing @@ -339,46 +348,40 @@ primFunDef def args = case def of _ -> err IntFromNat -> case args of [exprToNat -> Just n] -> - Right $ int $ fromIntegral n + Right $ const $ int $ fromIntegral n _ -> err Animate -> case args of -- Since we can only translate a `Picture` expression to an image once it is in normal form, -- this guard will only pass when `picture` has no free variables other than `time`. - [PrimCon () (PrimInt duration), Lam () time picture] - | Just (frames :: [Diagram Rasterific]) <- traverse diagramAtTime [0 .. duration * 100 `div` frameLength] -> - Right - $ prim - $ PrimAnimation - $ either - -- This case really shouldn't be able to happen, unless `diagrams-rasterific` is broken. - -- In fact, the default behaviour (`animatedGif`) is just to write the error to `stdout`, - -- and we only have to handle this because we need to use the lower-level `rasterGif`, - -- for unrelated reasons (getting the `Bytestring` without dumping it to a file). - mempty - (decodeUtf8 . B64.encode . toS) - $ encodeComplexGifImage - $ GifEncode (fromInteger width) (fromInteger height) Nothing Nothing gifLooping - $ flip palettizeWithAlpha DisposalRestoreBackground - $ map - ( (fromInteger frameLength,) - . renderDia - Rasterific - (RasterificOptions (mkSizeSpec $ Just . fromInteger <$> V2 width height)) - . rectEnvelope - (fromInteger <$> mkP2 (-width `div` 2) (-height `div` 2)) - (fromInteger <$> V2 width height) - ) - frames + [PrimCon () (PrimInt duration), Lam () time picture] | freeVars picture `isSubsetOf` Set.singleton (unLocalName time) -> Right \eval -> do + frames0 <- for [0 .. duration * 100 `div` frameLength] \t -> + -- TODO let the evaluator do the beta reduction as well? + fmap exprToDiagram . eval =<< generateIDs (Let () time (PrimCon () (PrimInt t)) picture) + -- TODO better error handling + let (frames :: [Diagram Rasterific]) = fromMaybe [text "error" <> (circle 40 & fillColor blue)] $ sequence frames0 + prim + $ PrimAnimation + $ either + -- This case really shouldn't be able to happen, unless `diagrams-rasterific` is broken. + -- In fact, the default behaviour (`animatedGif`) is just to write the error to `stdout`, + -- and we only have to handle this because we need to use the lower-level `rasterGif`, + -- for unrelated reasons (getting the `Bytestring` without dumping it to a file). + mempty + (decodeUtf8 . B64.encode . toS) + $ encodeComplexGifImage + $ GifEncode (fromInteger width) (fromInteger height) Nothing Nothing gifLooping + $ flip palettizeWithAlpha DisposalRestoreBackground + $ map + ( (fromInteger frameLength,) + . renderDia + Rasterific + (RasterificOptions (mkSizeSpec $ Just . fromInteger <$> V2 width height)) + . rectEnvelope + (fromInteger <$> mkP2 (-width `div` 2) (-height `div` 2)) + (fromInteger <$> V2 width height) + ) + frames where - -- Note that this simple substitution hack only allows for trivial functions, - -- i.e. those where only substitution is needed for the function body to reach a normal form. - -- Our primitives system doesn't yet support further evaluation here. - diagramAtTime t = exprToDiagram $ substTime (PrimCon () (PrimInt t)) picture - where - substTime a = \case - Var () (LocalVarRef t') | t' == time -> a - Con () c es -> Con () c $ map (substTime a) es - e -> e -- Values which are hardcoded, for now at least, for the sake of keeping the student-facing API simple. -- We keep the frame rate and resolution low to avoid serialising huge GIFs. gifLooping = LoopingForever @@ -388,7 +391,7 @@ primFunDef def args = case def of _ -> err PrimConst -> case args of [x, _] -> - Right $ generateIDs x `ann` tcon tBool + Right $ const $ generateIDs x `ann` tcon tBool _ -> err where exprToNat = \case