Skip to content

[AST] [Performance] Use 'SmallArray' instead of 'Vector' #7010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plutus-core/plutus-core/src/PlutusCore/Compiler/Erase.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module PlutusCore.Compiler.Erase (eraseTerm, eraseProgram) where

import Data.Vector (fromList)
import GHC.IsList (fromList)
import PlutusCore.Core
import PlutusCore.Name.Unique
import UntypedPlutusCore.Core qualified as UPLC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import PlutusCore.Rename.Monad
import Universe

import Data.Hashable
import Data.Vector qualified as V
import Data.Primitive.SmallArray (SmallArray)

instance (GEq uni, Closed uni, uni `Everywhere` Eq, Eq fun, Eq ann) =>
Eq (Term Name uni fun ann) where
Expand All @@ -37,7 +37,7 @@ type HashableTermConstraints uni fun ann =

-- This instance is the only logical one, and exists also in the package `vector-instances`.
-- Since this is the same implementation as that one, there isn't even much risk of incoherence.
instance Hashable a => Hashable (V.Vector a) where
instance Hashable a => Hashable (SmallArray a) where
hashWithSalt s = hashWithSalt s . toList

instance HashableTermConstraints uni fun ann => Hashable (Term Name uni fun ann)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ import UntypedPlutusCore.Core.Type

import Control.Lens
import Control.Monad
import Data.Vector qualified as V
import Flat
import Flat.Decoder
import Flat.Encoder
import Flat.Encoder.Strict (sizeListWith)
import GHC.IsList (fromList)
import Universe

{-
Expand Down Expand Up @@ -123,7 +123,7 @@ encodeTerm = \case
Error ann -> encodeTermTag 6 <> encode ann
Builtin ann bn -> encodeTermTag 7 <> encode ann <> encode bn
Constr ann i es -> encodeTermTag 8 <> encode ann <> encode i <> encodeListWith encodeTerm es
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm (V.toList cs)
Case ann arg cs -> encodeTermTag 9 <> encode ann <> encodeTerm arg <> encodeListWith encodeTerm (toList cs)

decodeTerm
:: forall name uni fun ann
Expand Down Expand Up @@ -160,7 +160,7 @@ decodeTerm version builtinPred = go
Constr <$> decode <*> decode <*> decodeListWith go
handleTerm 9 = do
unless (version >= PLC.plcVersion110) $ fail $ "'case' is not allowed before version 1.1.0, this program has version: " ++ (show $ pretty version)
Case <$> decode <*> go <*> (V.fromList <$> decodeListWith go)
Case <$> decode <*> go <*> (fromList <$> decodeListWith go)
handleTerm t = fail $ "Unknown term constructor tag: " ++ show t

sizeTerm
Expand Down Expand Up @@ -188,7 +188,7 @@ sizeTerm tm sz =
Error ann -> size ann sz'
Builtin ann bn -> size ann $ size bn sz'
Constr ann i es -> size ann $ size i $ sizeListWith sizeTerm es sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm (V.toList cs) sz'
Case ann arg cs -> size ann $ sizeTerm arg $ sizeListWith sizeTerm (toList cs) sz'

-- | An encoder for programs.
--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import PlutusCore.Name.Unique
import PlutusCore.Quote

import Data.Proxy
import Data.Vector qualified as Vector
import GHC.IsList (fromList)

firstBound :: Term name uni fun ann -> [name]
firstBound (Apply _ (LamAbs _ name body) _) = name : firstBound body
Expand All @@ -41,15 +41,15 @@ instance name ~ Name => EstablishScoping (Term name uni fun) where
establishScoping (Constr _ i es) = Constr NotAName <$> pure i <*> traverse establishScoping es
establishScoping (Case _ a es) = do
esScoped <- traverse establishScoping es
let esScopedPoked = addTheRest . map (\e -> (e, firstBound e)) $ Vector.toList esScoped
let esScopedPoked = addTheRest . map (\e -> (e, firstBound e)) $ toList esScoped
branchBounds = map (snd . fst) esScopedPoked
referenceInBranch ((branch, _), others) = referenceOutOfScope (map snd others) branch
aScoped <- establishScoping a
-- For each of the branches reference (as out-of-scope) the variables bound in that branch
-- in all the other ones, as well as outside of the whole case-expression. This is to check
-- that none of the transformations leak variables outside of the branch they're bound in.
pure . referenceOutOfScope branchBounds $
Case NotAName aScoped . Vector.fromList $ map referenceInBranch esScopedPoked
Case NotAName aScoped . fromList $ map referenceInBranch esScopedPoked

instance name ~ Name => EstablishScoping (Program name uni fun) where
establishScoping (Program _ ver term) = Program NotAName ver <$> establishScoping term
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ module UntypedPlutusCore.Core.Type
import Control.Lens
import PlutusPrelude

import Data.Vector
import Data.Primitive.SmallArray (SmallArray)
import Data.Word
import GHC.IsList (fromList)
import PlutusCore.Builtin qualified as TPLC
import PlutusCore.Core qualified as TPLC
import PlutusCore.MkPlc
Expand Down Expand Up @@ -86,7 +87,7 @@ data Term name uni fun ann
-- TODO: try spine-strict list or strict list or vector
-- See Note [Constr tag type]
| Constr !ann !Word64 ![Term name uni fun ann]
| Case !ann !(Term name uni fun ann) !(Vector (Term name uni fun ann))
| Case !ann !(Term name uni fun ann) {-# UNPACK #-} !(SmallArray (Term name uni fun ann))
deriving stock (Functor, Generic)

deriving stock instance (Show name, GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module UntypedPlutusCore.Core.Zip

import Control.Monad (void, when)
import Control.Monad.Except (MonadError, throwError)
import Data.Vector
import GHC.IsList (fromList, toList)
import UntypedPlutusCore.Core.Instance.Eq ()
import UntypedPlutusCore.Core.Type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,16 @@ import Data.DList qualified as DList
import Data.Functor.Identity
import Data.Hashable (Hashable)
import Data.Kind qualified as GHC
import Data.Primitive.SmallArray (SmallArray, indexSmallArray, sizeofSmallArray)
import Data.Proxy
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word
import GHC.Generics
import GHC.TypeLits
import Prettyprinter
import Universe
import Unsafe.Coerce (unsafeCoerce)

{- Note [Compilation peculiarities]
READ THIS BEFORE TOUCHING ANYTHING IN THIS FILE
Expand Down Expand Up @@ -600,7 +601,7 @@ data Context uni fun ann
-- See Note [Accumulators for terms]
| FrameConstr !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
-- ^ @(constr i V0 ... Vj-1 _ Nj ... Nn)@
| FrameCases !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
| FrameCases !(CekValEnv uni fun ann) {-# UNPACK #-} !(SmallArray (NTerm uni fun ann)) !(Context uni fun ann)
-- ^ @(case _ C0 .. Cn)@
| NoFrame

Expand Down Expand Up @@ -775,16 +776,11 @@ enterComputeCek = computeCek
[] -> returnCek ctx $ VConstr i done'
-- s , case _ (C0 ... CN, ρ) ◅ constr i V1 .. Vm ↦ s , [_ V1 ... Vm] ; ρ ▻ Ci
returnCek (FrameCases env cs ctx) e = case e of
-- If the index is larger than the max bound of an Int, or negative, then it's a bad index
-- As it happens, this will currently never trigger, since i is a Word64, and the largest
-- Word64 value wraps to -1 as an Int64. So you can't wrap around enough to get an
-- "apparently good" value.
(VConstr i _) | fromIntegral @_ @Integer i > fromIntegral @Int @Integer maxBound ->
throwingDischarged _MachineError (MissingCaseBranchMachineError i) e
-- Otherwise, we can safely convert the index to an Int and use it
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
Just t -> computeCek (transferArgStack args ctx) env t
Nothing -> throwingDischarged _MachineError (MissingCaseBranchMachineError i) e
VConstr i args
| i < unsafeCoerce (sizeofSmallArray cs) ->
computeCek (transferArgStack args ctx) env . indexSmallArray cs $ unsafeCoerce i
| otherwise ->
throwingDischarged _MachineError (MissingCaseBranchMachineError i) e
_ -> throwingDischarged _MachineError NonConstrScrutinizedMachineError e

-- | Evaluate a 'HeadSpine' by pushing the arguments (if any) onto the stack and proceeding with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ import UntypedPlutusCore.Evaluation.Machine.Cek.StepCounter
import Control.Lens hiding (Context)
import Control.Monad
import Control.Monad.Primitive
import Data.Primitive.SmallArray (SmallArray)
import Data.Proxy
import Data.RandomAccessList.Class qualified as Env
import Data.Semigroup (stimes)
import Data.Text (Text)
import Data.Vector qualified as V
import Data.Word (Word64)
import GHC.TypeNats
import Universe
Expand Down Expand Up @@ -100,7 +100,7 @@ data Context uni fun ann
| FrameAwaitFunValue ann !(CekValue uni fun ann) !(Context uni fun ann)
| FrameForce ann !(Context uni fun ann) -- ^ @(force _)@
| FrameConstr ann !(CekValEnv uni fun ann) {-# UNPACK #-} !Word64 ![NTerm uni fun ann] !(ArgStack uni fun ann) !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) !(V.Vector (NTerm uni fun ann)) !(Context uni fun ann)
| FrameCases ann !(CekValEnv uni fun ann) !(SmallArray (NTerm uni fun ann)) !(Context uni fun ann)
| NoFrame

deriving stock instance (GShow uni, Everywhere uni Show, Show fun, Show ann, Closed uni)
Expand Down Expand Up @@ -206,7 +206,7 @@ returnCek (FrameCases ann env cs ctx) e = case e of
-- "apparently good" value.
(VConstr i _) | fromIntegral @_ @Integer i > fromIntegral @Int @Integer maxBound ->
throwingDischarged _MachineError (MissingCaseBranchMachineError i) e
(VConstr i args) -> case (V.!?) cs (fromIntegral i) of
(VConstr i args) -> case toList cs ^? ix (fromIntegral i) of
Just t ->
let ctx' = transferArgStack ann args ctx
in computeCek ctx' env t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import UntypedPlutusCore.Core.Type qualified as UPLC
import UntypedPlutusCore.Rename (Rename (rename))

import Data.Text (Text)
import Data.Vector qualified as V
import GHC.IsList (fromList)
import PlutusCore.Error (AsParserErrorBundle)
import PlutusCore.MkPlc (mkIterApp)
import PlutusCore.Parser hiding (parseProgram, parseTerm, program)
Expand Down Expand Up @@ -82,7 +82,7 @@ constrTerm = withSpan $ \sp ->
caseTerm :: Parser PTerm
caseTerm = withSpan $ \sp ->
inParens $ do
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> (V.fromList <$> many term)
res <- UPLC.Case sp <$> (symbol "case" *> term) <*> (fromList <$> many term)
whenVersion (\v -> v < plcVersion110) $ fail "'case' is not allowed before version 1.1.0"
pure res

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import UntypedPlutusCore.Core
import UntypedPlutusCore.Transform.Simplifier (SimplifierStage (CaseReduce), SimplifierT,
recordSimplification)

import Control.Lens (transformOf)
import Data.Vector qualified as V
import Control.Lens (ix, transformOf, (^?))
import Data.Foldable (toList)

caseReduce
:: Monad m
Expand All @@ -23,6 +23,6 @@ caseReduce term = do

processTerm :: Term name uni fun a -> Term name uni fun a
processTerm = \case
Case ann (Constr _ i args) cs | Just c <- (V.!?) cs (fromIntegral i) ->
Case ann (Constr _ i args) cs | Just c <- toList cs ^? ix (fromIntegral i) ->
mkIterApp c ((ann,) <$> args)
t -> t
5 changes: 2 additions & 3 deletions plutus-core/untyped-plutus-core/test/Generators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
-- | UPLC property tests (pretty-printing\/parsing and binary encoding\/decoding).
module Generators where

import PlutusPrelude (display, fold, on, void, zipExact, (&&&))
import PlutusPrelude (display, fold, on, toList, void, zipExact, (&&&))

import PlutusCore (Name, _nameText)
import PlutusCore.Annotation
Expand All @@ -28,7 +28,6 @@ import UntypedPlutusCore.Parser (parseProgram, parseTerm)
import Control.Lens (view)
import Data.Text (Text)
import Data.Text qualified as T
import Data.Vector qualified as V

import Hedgehog (annotate, annotateShow, failure, property, tripping, (===))
import Hedgehog.Gen qualified as Gen
Expand Down Expand Up @@ -61,7 +60,7 @@ compareTerm (Delay _ t ) (Delay _ t') = compareTerm t t'
compareTerm (Constant _ x) (Constant _ y) = x == y
compareTerm (Builtin _ bi) (Builtin _ bi') = bi == bi'
compareTerm (Constr _ i es) (Constr _ i' es') = i == i' && maybe False (all (uncurry compareTerm)) (zipExact es es')
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact (V.toList cs) (V.toList cs'))
compareTerm (Case _ arg cs) (Case _ arg' cs') = compareTerm arg arg' && maybe False (all (uncurry compareTerm)) (zipExact (toList cs) (toList cs'))
compareTerm (Error _ ) (Error _ ) = True
compareTerm _ _ = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module Transform.CaseOfCase.Test where

import Data.ByteString.Lazy qualified as BSL
import Data.Text.Encoding (encodeUtf8)
import Data.Vector qualified as V
import GHC.IsList (fromList)
import PlutusCore qualified as PLC
import PlutusCore.Evaluation.Machine.BuiltinCostModel (BuiltinCostModel)
import PlutusCore.Evaluation.Machine.ExBudgetingDefaults (defaultBuiltinCostModelForTesting,
Expand Down Expand Up @@ -45,7 +45,7 @@ caseOfCase1 = runQuote do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Constr () 0 []
false = Constr () 1 []
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | This should not simplify, because one of the branches of `ifThenElse` is not a `Constr`.
Expand All @@ -59,7 +59,7 @@ caseOfCase2 = runQuote do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Var () t
false = Constr () 1 []
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | Similar to `caseOfCase1`, but the type of the @true@ and @false@ branches is
Expand All @@ -76,7 +76,7 @@ caseOfCase3 = runQuote do
false = Constr () 1 []
altTrue = Var () f
altFalse = mkConstant @Integer () 2
alts = V.fromList [altTrue, altFalse]
alts = fromList [altTrue, altFalse]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- |
Expand Down Expand Up @@ -107,7 +107,7 @@ caseOfCaseWithError =
, ((), Constr () 1 []) -- False
]
)
(V.fromList [mkConstant @() () (), Error ()])
(fromList [mkConstant @() () (), Error ()])

testCaseOfCaseWithError :: TestTree
testCaseOfCaseWithError =
Expand Down
10 changes: 5 additions & 5 deletions plutus-core/untyped-plutus-core/test/Transform/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
module Transform.Simplify where

import Data.Text (Text)
import Data.Vector qualified as V
import GHC.IsList (fromList)
import PlutusCore qualified as PLC
import PlutusCore.MkPlc (mkConstant, mkIterApp, mkIterAppNoAnn)
import PlutusCore.Quote (Quote, freshName, runQuote)
Expand Down Expand Up @@ -34,7 +34,7 @@ caseOfCase1 = runQuote $ do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Constr () 0 []
false = Constr () 1 []
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | This should not simplify, because one of the branches of `ifThenElse` is not a `Constr`.
Expand All @@ -48,7 +48,7 @@ caseOfCase2 = runQuote $ do
let ite = Force () (Builtin () PLC.IfThenElse)
true = Var () t
false = Constr () 1 []
alts = V.fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
alts = fromList [mkConstant @Integer () 1, mkConstant @Integer () 2]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

{- | Similar to `caseOfCase1`, but the type of the @true@ and @false@ branches is
Expand All @@ -65,7 +65,7 @@ caseOfCase3 = runQuote $ do
false = Constr () 1 []
altTrue = Var () f
altFalse = mkConstant @Integer () 2
alts = V.fromList [altTrue, altFalse]
alts = fromList [altTrue, altFalse]
pure $ Case () (mkIterApp ite [((), Var () b), ((), true), ((), false)]) alts

-- | The `Delay` should be floated into the lambda.
Expand Down Expand Up @@ -408,7 +408,7 @@ cse1 = runQuote $ do
branch1 = plus onePlusTwoPlusX threePlusX
branch2 = plus twoPlusX threePlusX
branch3 = fourPlusX
caseExpr = Case () (Var () y) (V.fromList [branch1, branch2, branch3])
caseExpr = Case () (Var () y) (fromList [branch1, branch2, branch3])
pure $ LamAbs () x (LamAbs () y body)

-- | This is the second example in Note [CSE].
Expand Down
Loading