diff --git a/readme.org b/readme.org index 57063ef..5a7e64d 100644 --- a/readme.org +++ b/readme.org @@ -744,21 +744,8 @@ import Data.Vector qualified as V #+end_src #+RESULTS: -#+begin_example -Build profile: -w ghc-9.8.2 -O1 -In order, the following will be built (use -v for more details): - - numhask-array-0.12 (lib) (file src/NumHask/Array/Fixed.hs changed) -Preprocessing library for numhask-array-0.12.. -GHCi, version 9.8.2: https://www.haskell.org/ghc/ :? for help -Loaded GHCi configuration from /Users/tonyday567/haskell/numhask-array/.ghci -[1 of 5] Compiling NumHask.Array.Shape ( src/NumHask/Array/Shape.hs, interpreted ) -[2 of 5] Compiling NumHask.Array.Sort ( src/NumHask/Array/Sort.hs, interpreted ) -[3 of 5] Compiling NumHask.Array.Dynamic ( src/NumHask/Array/Dynamic.hs, interpreted ) -[4 of 5] Compiling NumHask.Array.Fixed ( src/NumHask/Array/Fixed.hs, interpreted ) -[5 of 5] Compiling NumHask.Array ( src/NumHask/Array.hs, interpreted ) -Ok, five modules loaded. -Ok, five modules loaded. -#+end_example +: [4 of 5] Compiling NumHask.Array.Fixed ( src/NumHask/Array/Fixed.hs, interpreted ) [Source file changed] +: Ok, five modules loaded. #+begin_src haskell-ng a = range @[2,3,4] @@ -1505,9 +1492,10 @@ https://github.com/goldfirere/singletons/blob/master/README.md ** dependent type examples #+begin_src haskell-ng :results output +:{ example_inserta :: (Show a, FromInteger a) => SomeArray a -> String -example_inserta (SomeArray (SNats :: SNats ns) a) = show (insert (SNat @0) 0 a (toScalar 0)) --- Could not deduce ‘KnownNats +example_inserta (SomeArray (SNats :: SNats ns) a) = show (insert (SNat @0) (SNat @0) a (toScalar 0)) +:} -- (If (s Data.Type.Equality.== '[]) '[1] s)’ -- arising from a use of ‘insert’ #+end_src @@ -1515,6 +1503,11 @@ example_inserta (SomeArray (SNats :: SNats ns) a) = show (insert (SNat @0) 0 a ( segfaults as SNats is somehow SNat @'[] #+begin_src haskell-ng :results output +:{ +import Fcf (Eval) +someTakeDim :: forall d t s. (KnownNats s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) +someTakeDim = withSomeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) unsafeCoerce + example_take :: forall a s. (KnownNats s, Show a) => Nat -> Nat -> Array s a -> String example_take d t a = withSomeSNat d @@ -1522,10 +1515,17 @@ example_take d t a = withSomeSNat t (\(SNat :: SNat t) -> case someTakeDim @d @t @s of - SNats -> show $ take (SNat @d) (SNat @t) a)) + SNats -> show $ F.take (SNat @d) (SNat @t) a)) +:} #+end_src +but this is ok + #+begin_src haskell-ng :results output +:{ +someTakeDim2 :: forall d t s. (KnownNats s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) +someTakeDim2 = UnsafeSNats (fromIntegral <$> takeDim (int @d) (int @t) (ints @s)) + example_take' :: forall a s. (KnownNats s, Show a) => Nat -> Nat -> Array s a -> String example_take' d t a = withSomeSNat d @@ -1534,64 +1534,6 @@ example_take' d t a = (\(SNat :: SNat t) -> case someTakeDim2 @d' @t @s of x -> show x <> " " <> show (someTakeDim2 @d' @t @s) <> " : take " <> show (SNat @d') <> " " <> show (SNat @t) <> " " <> show a)) -#+end_src - -example_take' 0 1 (range @[2,3,4]) -"SNats @[1, 3, 4] SNats @[1, 3, 4] : take SNat @0 SNat @1 [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]" - -vector examples - -#+begin_src haskell-ng :results output -data SomeVector a where - SomeVector :: KnownNat n => Vector n a -> SomeVector a - -deriving instance (Show a) => Show (SomeVector a) - -withLength :: forall n a r. Vector n a -> (KnownNat n => r) -> r -withLength v r = case someNatVal (fromIntegral $ V.length (asVector v)) of - SomeNat (Proxy :: Proxy n') -> case unsafeCoerce Refl of - (Refl :: n :~: n') -> r - -aVector :: FromVector t a => t -> SomeVector a -aVector (Array . asVector -> v) = withLength v (SomeVector v) -example_append :: (Show a, Num a, FromInteger a) => SomeVector a -> String -example_append (SomeVector a) = show (append (SNat @0) a (toScalar 0)) - -example_insert :: (Show a, FromInteger a) => SomeVector a -> String -example_insert (SomeVector a) = show (insert (SNat @0) 0 a (toScalar 0)) - -data SomeVector' a = forall n. SomeVector' (SNat n) (Vector n a) - -deriving instance (Show a) => Show (SomeVector' a) - -someVector' :: FromVector t a => KnownNat n => SNat n -> t -> SomeVector' a -someVector' n t = SomeVector' n (vector' n t) - -aVector' :: forall a t. FromVector t a => t -> SomeVector' a -aVector' t = withSomeSNat (fromIntegral $ V.length (asVector t)) $ \(SNat :: SNat n) -> SomeVector' SNat (vector' (SNat @n) (asVector t)) - -example_insert' :: (Show a, FromInteger a) => SomeVector' a -> String -example_insert' (SomeVector' (SNat :: SNat n) a) = show (insert (SNat @0) 0 a (toScalar 0)) - -example_append' :: (Show a, Num a, FromInteger a) => SomeVector' a -> String -example_append' (SomeVector' (SNat :: SNat n) a) = show (append (SNat @0) a (toScalar 0)) - -instance (Arbitrary a) => Arbitrary (SomeVector' a) where - arbitrary = do - n <- arbitrary - v <- V.replicateM (Prelude.fromIntegral n) arbitrary - withSomeNat n $ \sn -> pure (someVector' sn v) -#+end_src - - -#+begin_src haskell-ng :results output -someTakeDim :: forall d t s. (KnownNats s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) -someTakeDim = withSomeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) unsafeCoerce - -someTakeDim2 :: forall d t s. (KnownNats s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) -someTakeDim2 = UnsafeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) - -someUnsafeGetIndex :: forall d s. (KnownNats s, KnownNat d) => SNat (Eval (UnsafeGetIndex d s)) -someUnsafeGetIndex = withSomeSNat (fromIntegral $ getDim (int (SNat :: SNat d)) (ints (SNats :: SNats s))) unsafeCoerce +:} #+end_src diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index 9790b78..a199f71 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -83,15 +83,15 @@ module NumHask.Array.Fixed drop, dropB, select, - concatenate, insert, delete, append, prepend, + concatenate, couple, slice, - -- * Operators + -- * Multi-dimension Operators takes, takeBs, drops, @@ -191,7 +191,6 @@ import Data.Functor.Classes import Data.Functor.Rep import Data.Vector qualified as V import Fcf hiding (type (&&), type (+), type (-), type (++)) -import Fcf qualified import Fcf.Data.List import GHC.TypeNats import NumHask.Array.Dynamic qualified as D @@ -478,7 +477,6 @@ instance Foldable SomeArray where foldMap f (SomeArray _ a) = foldMap f a - someArray :: forall s t a. FromVector t a => SNats s -> t -> SomeArray a someArray s t = SomeArray s (Array (asVector t)) @@ -756,20 +754,20 @@ imap :: Array s b imap f a = zipWith f indices a --- | Apply a function that takes dimensions and a SNat parameter and applies a parameter list to the initial dimensions. ie +-- | Apply a function that takes dimensions and (type-level) parameters and applies a parameters to the initial dimensions. ie -- -- > rowWise f xs = f [0..rank xs - 1] xs -- -- >>> toDynamic $ rowWise indexesT (S.SNats @[1,0]) a -- UnsafeArray [4] [12,13,14,15] rowWise :: - forall a ds s s' xs. + forall a ds s s' xs proxy. ( KnownNats s , KnownNats ds , ds ~ Eval (DimsOf xs) ) => - (SNats ds -> SNats xs -> Array s a -> Array s' a) -> - SNats xs -> Array s a -> Array s' a + (Dims ds -> proxy xs -> Array s a -> Array s' a) -> + proxy xs -> Array s a -> Array s' a rowWise f xs a = f (SNats @ds) xs a -- | Apply a function that takes a (dimension,parameter) list and applies a parameter list to the the last dimensions (in reverse). ie @@ -779,12 +777,12 @@ rowWise f xs a = f (SNats @ds) xs a -- >>> toDynamic $ colWise indexesT (S.SNats @[1,0]) a -- UnsafeArray [2] [1,13] colWise :: - forall a ds s s' xs. + forall a ds s s' xs proxy. ( KnownNats s , KnownNats ds , ds ~ Eval (EndDimsOf xs s)) => - (SNats ds -> SNats xs -> Array s a -> Array s' a) -> - SNats xs -> Array s a -> Array s' a + (Dims ds -> proxy xs -> Array s a -> Array s' a) -> + proxy xs -> Array s a -> Array s' a colWise f xs a = f (SNats @ds) xs a -- | Take the top-most elements across the specified dimension. @@ -802,7 +800,7 @@ take :: KnownNats s', s' ~ Eval (TakeDim d t s) ) => - SNat d -> + Dim d -> SNat t -> Array s a -> Array s' a @@ -823,7 +821,7 @@ takeB :: KnownNats s', s' ~ Eval (TakeDim d t s) ) => - SNat d -> + Dim d -> SNat t -> Array s a -> Array s' a @@ -844,7 +842,7 @@ drop :: KnownNats s', Eval (DropDim d t s) ~ s' ) => - SNat d -> + Dim d -> SNat t -> Array s a -> Array s' a @@ -865,7 +863,7 @@ dropB :: KnownNats s', Eval (DropDim d t s) ~ s' ) => - SNat d -> + Dim d -> SNat t -> Array s a -> Array s' a @@ -873,56 +871,22 @@ dropB _ _ a = unsafeBackpermute id a -- | Select an index along a dimension. -- --- >>> let s = select (SNat @2) (SNat @3) a +-- >>> let s = select (SNat @2) (S.fin @4 3) a -- >>> pretty s -- [[3,7,11], -- [15,19,23]] select :: - forall d x a s s'. + forall d a p s s'. (KnownNats s, KnownNats s', - s' ~ Eval (DeleteDim d s)) => - SNat d -> - SNat x -> + s' ~ Eval (DeleteDim d s), + p ~ Eval (GetDim d s) + ) => + Dim d -> + Fin p -> Array s a -> Array s' a -select SNat SNat a = unsafeBackpermute (S.insertDim (valueOf @d) (valueOf @x)) a - --- | Concatenate along a dimension. --- --- >>> shape $ concatenate (SNat @1) a a --- [2,6,4] --- >>> toDynamic $ concatenate (SNat @0) (toScalar 1) (toScalar 2) --- UnsafeArray [2] [1,2] --- >>> toDynamic $ concatenate (SNat @0) (array @'[1] [0]) (array @'[3] [1..3]) --- UnsafeArray [4] [0,1,2,3] -concatenate :: - forall a s0 s1 d s. - ( KnownNats s0, - KnownNats s1, - KnownNats s, - Eval (Concatenate d s0 s1) ~ s - ) => - SNat d -> - Array s0 a -> - Array s1 a -> - Array s a -concatenate SNat a0 a1 = tabulate (go . fromFins) - where - go s = - bool - (index a0 (UnsafeFins s)) - ( index - a1 - ( UnsafeFins $ insertDim - d' - (getDim d' s - getDim d' ds0) - (deleteDim d' s) - ) - ) - (getDim d' s >= getDim d' ds0) - ds0 = shape a0 - d' = valueOf @d +select SNat p a = unsafeBackpermute (S.insertDim (valueOf @d) (fromFin p)) a -- | Insert along a dimension at a position. -- @@ -944,7 +908,7 @@ insert :: p ~ Eval (GetDim d s), True ~ Eval (InsertOk d s si) ) => - SNat d -> + Dim d -> Fin p -> Array s a -> Array si a -> @@ -973,7 +937,7 @@ delete :: KnownNats s', s' ~ Eval (DecAt d s), p ~ 1 + Eval (GetDim d s)) => - SNat d -> + Dim d -> Fin p -> Array s a -> Array s' a @@ -998,7 +962,7 @@ append :: s' ~ Eval (IncAt d s), True ~ Eval (InsertOk d s si) ) => - SNat d -> + Dim d -> Array s a -> Array si a -> Array s' a @@ -1021,12 +985,48 @@ prepend :: s' ~ Eval (IncAt d s), True ~ Eval (InsertOk d s si) ) => - SNat d -> + Dim d -> Array si a -> Array s a -> Array s' a prepend d a b = insert d (UnsafeFin 0) b a +-- | Concatenate along a dimension. +-- +-- >>> shape $ concatenate (SNat @1) a a +-- [2,6,4] +-- >>> toDynamic $ concatenate (SNat @0) (toScalar 1) (toScalar 2) +-- UnsafeArray [2] [1,2] +-- >>> toDynamic $ concatenate (SNat @0) (array @'[1] [0]) (array @'[3] [1..3]) +-- UnsafeArray [4] [0,1,2,3] +concatenate :: + forall a s0 s1 d s. + ( KnownNats s0, + KnownNats s1, + KnownNats s, + Eval (Concatenate d s0 s1) ~ s + ) => + Dim d -> + Array s0 a -> + Array s1 a -> + Array s a +concatenate SNat a0 a1 = tabulate (go . fromFins) + where + go s = + bool + (index a0 (UnsafeFins s)) + ( index + a1 + ( UnsafeFins $ insertDim + d' + (getDim d' s - getDim d' ds0) + (deleteDim d' s) + ) + ) + (getDim d' s >= getDim d' ds0) + ds0 = shape a0 + d' = valueOf @d + -- | Combine two arrays as rows of a new array. -- -- >>> pretty $ couple (array @'[3] [1,2,3]) (array @'[3] @Int [4,5,6]) @@ -1059,7 +1059,7 @@ slice :: KnownNats s', s' ~ Eval (SetDim d l s), Eval (SliceOk d off l s) ~ True) => - SNat d -> + Dim d -> SNat off -> SNat l -> Array s a -> @@ -1079,7 +1079,7 @@ takes :: KnownNats s', s' ~ Eval (SetDims ds xs s) ) => - SNats ds -> + Dims ds -> SNats xs -> Array s a -> Array s' a @@ -1098,7 +1098,7 @@ takeBs :: KnownNats xs, s' ~ Eval (SetDims ds xs s) ) => - SNats ds -> + Dims ds -> SNats xs -> Array s a -> Array s' a @@ -1120,7 +1120,7 @@ drops :: KnownNats xs, s' ~ Eval (DropDims ds xs s) ) => - SNats ds -> + Dims ds -> SNats xs -> Array s a -> Array s' a @@ -1142,7 +1142,7 @@ dropBs :: KnownNats xs, s' ~ Eval (DropDims ds xs s) ) => - SNats ds -> + Dims ds -> SNats xs -> Array s a -> Array s' a @@ -1159,7 +1159,7 @@ indexes :: s' ~ Eval (DeleteDims ds s), ts ~ Eval (GetDims ds s) ) => - SNats ds -> + Dims ds -> Fins ts -> Array s a -> Array s' a @@ -1175,9 +1175,10 @@ indexesT :: KnownNats ds, KnownNats xs, KnownNats s', - s' ~ Eval (DeleteDims ds s) + s' ~ Eval (DeleteDims ds s), + True ~ Eval (IsFins xs =<< GetDims ds s) ) => - SNats ds -> + Dims ds -> SNats xs -> Array s a -> Array s' a @@ -1200,7 +1201,7 @@ indexesExcept :: s' ~ Eval (GetDims ds s), ts ~ Eval (DeleteDims ds s) ) => - SNats ds -> + Dims ds -> Fins ts -> Array s a -> Array s' a @@ -1224,7 +1225,7 @@ slices :: KnownNats offs, Eval (SlicesOk ds offs ls s) ~ True, Eval (SetDims ds ls s) ~ s') => - SNats ds -> + Dims ds -> SNats offs -> SNats ls -> Array s a -> @@ -1243,7 +1244,7 @@ heads :: KnownNats s', KnownNats ds, s' ~ Eval (DeleteDims ds s)) => - SNats ds -> + Dims ds -> Array s a -> Array s' a heads ds a = indexes ds (UnsafeFins $ replicate (rankOf @s) zero) a @@ -1259,7 +1260,7 @@ lasts :: KnownNats s', s' ~ Eval (DeleteDims ds s) ) => - SNats ds -> + Dims ds -> Array s a -> Array s' a lasts ds a = indexes ds (UnsafeFins lastxs) a @@ -1281,10 +1282,10 @@ tails :: KnownNats os, Eval (SlicesOk ds os ls s) ~ True, os ~ Eval (Replicate (Eval (Rank ds)) 1), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), s' ~ Eval (SetDims ds ls s) ) => - SNats ds -> + Dims ds -> Array s a -> Array s' a tails ds a = slices ds (SNats @os) (SNats @ls) a @@ -1304,10 +1305,10 @@ inits :: KnownNats os, Eval (SlicesOk ds os ls s) ~ True, os ~ Eval (Replicate (Eval (Rank ds)) 0), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), s' ~ Eval (SetDims ds ls s) ) => - SNats ds -> + Dims ds -> Array s a -> Array s' a inits ds a = slices ds (SNats @os) (SNats @ls) a @@ -1327,7 +1328,7 @@ extracts :: si ~ Eval (DeleteDims ds st), so ~ Eval (GetDims ds st) ) => - SNats ds -> + Dims ds -> Array st a -> Array so (Array si a) extracts ds a = tabulate (\s -> indexes ds s a) @@ -1347,7 +1348,7 @@ extractsExcept :: so ~ Eval (DeleteDims ds st), si ~ Eval (GetDims ds st) ) => - SNats ds -> + Dims ds -> Array st a -> Array so (Array si a) extractsExcept ds a = tabulate (\s -> indexesExcept ds s a) @@ -1369,7 +1370,7 @@ reduces :: si ~ Eval (DeleteDims ds st), so ~ Eval (GetDims ds st) ) => - SNats ds -> + Dims ds -> (Array si a -> b) -> Array st a -> Array so b @@ -1388,7 +1389,7 @@ joins :: KnownNats si, KnownNats so, Eval (InsertDims ds so si) ~ st) => - SNats ds -> + Dims ds -> Array so (Array si a) -> Array st a joins _ a = tabulate go @@ -1424,7 +1425,7 @@ traverses :: KnownNats (Eval (DeleteDims ds s)), KnownNats (Eval (GetDims ds s)) ) => - SNats ds -> + Dims ds -> (a -> f b) -> Array s a -> f (Array s b) @@ -1447,7 +1448,7 @@ maps :: s ~ Eval (InsertDims ds so si) ) => (Array si a -> Array si' b) -> - SNats ds -> + Dims ds -> Array s a -> Array s' b maps f SNats a = joins (SNats @ds) (fmapRep f (extracts (SNats @ds) a)) @@ -1463,7 +1464,7 @@ filters :: si ~ Eval (DeleteDims ds so), KnownNats (Eval (GetDims ds so)) ) => - SNats ds -> + Dims ds -> (Array si a -> Bool) -> Array so a -> D.Array (Array si a) @@ -1490,7 +1491,7 @@ zips :: s' ~ Eval (InsertDims ds so si'), s ~ Eval (InsertDims ds so si) ) => - SNats ds -> + Dims ds -> (Array si a -> Array si b -> Array si' c) -> Array s a -> Array s b -> @@ -1515,7 +1516,7 @@ modifies :: so ~ Eval (GetDims ds s), s ~ Eval (InsertDims ds so si)) => (Array si a -> Array si a) -> - SNats ds -> + Dims ds -> Fins so -> Array s a -> Array s a @@ -1543,7 +1544,7 @@ diffs :: postDrop ~ Eval (InsertDims ds so si), postDrop ~ Eval (DropDims ds ls st) ) => - SNats ds -> + Dims ds -> SNats ls -> (Array si a -> Array si a -> Array si' b) -> Array st a -> Array st' b diffs SNats xs f a = zips (SNats @ds) f (drops (SNats @ds) xs a) (dropBs (SNats @ds) xs a) @@ -1622,7 +1623,7 @@ expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ List.drop r (fromFins i -- FIXME: relook at expand/contract structure -- -- > let b = array [1..6] :: Array [2,3] Int --- > pretty $ contract sum [1,2] (expand (*) b (transpose b)) +-- > pretty $ contract (SNats @[1,2]) sum (expand (*) b (transpose b)) -- [[14,32], -- [32,77]] contract :: @@ -1635,11 +1636,11 @@ contract :: s' ~ Eval (DeleteDims ds s), ss ~ '[Eval (Minimum (Eval (GetDims ds s)))] ) => + Dims ds -> (Array ss a -> b) -> - SNats ds -> Array s a -> Array s' b -contract f SNats a = f . diag <$> extractsExcept (SNats @ds) a +contract SNats f a = f . diag <$> extractsExcept (SNats @ds) a -- | A generalisation of a dot operation, which is a multiplicative expansion of two arrays and sum contraction along the middle two dimensions. -- @@ -1686,7 +1687,7 @@ dot :: Array sa a -> Array sb b -> Array s' d -dot f g a b = contract f (SNats :: SNats x) (expand g a b) +dot f g a b = contract (SNats :: SNats x) f (expand g a b) -- | Array multiplication. -- @@ -1766,7 +1767,7 @@ find :: ws ~ Eval (ExpandWindows i' s), r ~ Eval (Rank s), i' ~ Eval (Rerank r si), - re ~ Eval ((Fcf.++) (Eval (Range (Eval (Rank s)))) (Eval (EnumFromTo (r + r) (Eval ((Fcf.-) (Eval (Rank ws)) 1))))), + re ~ Eval ((Eval (Range =<< (Rank s))) ++ (Eval (EnumFromTo (r + r) (Eval (Rank ws) - 1)))), i' ~ Eval (DeleteDims re ws), s' ~ Eval (GetDims re ws) ) => @@ -1824,7 +1825,7 @@ isInfixOf :: ws ~ Eval (ExpandWindows i' s), r ~ Eval (Rank s), i' ~ Eval (Rerank r si), - re ~ Eval ((Fcf.++) (Eval (Range (Eval (Rank s)))) (Eval (EnumFromTo (r + r) (Eval ((Fcf.-) (Eval (Rank ws)) 1))))), + re ~ Eval ((Eval (Range =<< Rank s)) ++ (Eval (EnumFromTo (r + r) ((Eval (Rank ws)) - 1)))), i' ~ Eval (DeleteDims re ws), s' ~ Eval (GetDims re ws) ) => @@ -2085,7 +2086,7 @@ elongate :: (KnownNats s, KnownNats s', s' ~ Eval (InsertDim d 1 s)) => - SNat d -> + Dim d -> Array s a -> Array s' a elongate _ a = unsafeModifyShape a @@ -2115,7 +2116,7 @@ inflate :: (KnownNats s, KnownNats s', s' ~ Eval (InsertDim d x s)) => - SNat d -> + Dim d -> SNat x -> Array s a -> Array s' a @@ -2133,7 +2134,7 @@ concats :: (KnownNats s, KnownNats s', s' ~ Eval (InsertDim newd (Eval (Size (Eval (GetDims ds s)))) (Eval (DeleteDims ds s)))) => - SNats ds -> + Dims ds -> SNat newd -> Array s a -> Array s' a @@ -2214,7 +2215,7 @@ sorts :: so ~ Eval (GetDims ds s), s ~ Eval (InsertDims ds so si) ) => - SNats ds -> Array s a -> Array s a + Dims ds -> Array s a -> Array s a sorts SNats a = joins (SNats @ds) $ unsafeModifyVector sortV (extracts (SNats @ds) a) -- | The indices into the array if it were sorted by a comparison function along the dimensions supplied. @@ -2232,7 +2233,7 @@ sortsBy :: so ~ Eval (GetDims ds s), s ~ Eval (InsertDims ds so si) ) => - SNats ds -> (Array si a -> Array si b) -> Array s a -> Array s a + Dims ds -> (Array si a -> Array si b) -> Array s a -> Array s a sortsBy SNats c a = joins (SNats @ds) $ unsafeModifyVector (sortByV c) (extracts (SNats @ds) a) -- | The indices into the array if it were sorted along the dimensions supplied. @@ -2249,7 +2250,7 @@ orders :: so ~ Eval (GetDims ds s), s ~ Eval (InsertDims ds so si) ) => - SNats ds -> Array s a -> Array so Int + Dims ds -> Array s a -> Array so Int orders SNats a = unsafeModifyVector orderV (extracts (SNats @ds) a) -- | The indices into the array if it were sorted by a comparison function along the dimensions supplied. @@ -2267,7 +2268,7 @@ ordersBy :: so ~ Eval (GetDims ds s), s ~ Eval (InsertDims ds so si) ) => - SNats ds -> (Array si a -> Array si b) -> Array s a -> Array so Int + Dims ds -> (Array si a -> Array si b) -> Array s a -> Array so Int ordersBy SNats c a = unsafeModifyVector (orderByV c) (extracts (SNats @ds) a) -- | Apply a binary array function to two arrays with matching shapes across the supplied (matching) dimensions. @@ -2311,7 +2312,7 @@ transmit :: KnownNats sib, KnownNats sic, KnownNats sob, - ds ~ Eval (EnumFromTo (Eval (Rank sa)) (Eval ((Fcf.-) (Eval (Rank sb)) 1))), + ds ~ Eval (EnumFromTo (Eval (Rank sa)) (Eval (Rank sb) - 1)), sib ~ Eval (DeleteDims ds sb), sob ~ Eval (GetDims ds sb), sb ~ Eval (InsertDims ds sob sib), @@ -2434,7 +2435,7 @@ uncons :: KnownNats ls, KnownNats os, os ~ Eval (Replicate (Eval (Rank ds)) 1), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), Eval (SlicesOk ds os ls s) ~ True, st ~ Eval (SetDims ds ls s) ) => @@ -2457,7 +2458,7 @@ unsnoc :: ds ~ '[0], Eval (SlicesOk ds os ls s) ~ True, os ~ Eval (Replicate (Eval (Rank ds)) 0), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), si ~ Eval (SetDims ds ls s), sl ~ Eval (DeleteDims ds s) ) => Array s a -> (Array si a, Array sl a) @@ -2485,7 +2486,7 @@ pattern (:<) :: KnownNats os, Eval (SlicesOk ds os ls s) ~ True, os ~ Eval (Replicate (Eval (Rank ds)) 1), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), st ~ Eval (SetDims ds ls s)) => Array sh a -> Array st a -> Array s a pattern x :< xs <- (uncons -> (x, xs)) @@ -2519,7 +2520,7 @@ pattern (:>) :: ds ~ '[0], Eval (SlicesOk ds os ls s) ~ True, os ~ Eval (Replicate (Eval (Rank ds)) 0), - ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))), + ls ~ Eval (GetLastPositions ds s), si ~ Eval (SetDims ds ls s), sl ~ Eval (DeleteDims ds s) ) => diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index 54afe42..e2dd9ac 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -50,9 +50,16 @@ module NumHask.Array.Shape rankOf, sizeOf, Fin (..), + fin, safeFin, Fins (..), toFins, + + -- * Dimensions + Dim, + Dims, + + -- operators rank, Rank, Range, @@ -130,6 +137,7 @@ module NumHask.Array.Shape -- * multiple dimension getDims, GetDims, + GetLastPositions, modifyDims, insertDims, InsertDims, @@ -173,7 +181,7 @@ import GHC.TypeLits (TypeError, ErrorMessage(..)) import Text.Read import Data.Type.Ord hiding (Min, Max) import Unsafe.Coerce -import Fcf hiding (type (&&), type (||), type (+), type (-), type (++)) +import Fcf hiding (type (>), type (<), type (&&), type (||), type (+), type (-), type (++)) import Fcf qualified import Fcf.Class.Foldable import Fcf.Data.List @@ -292,6 +300,18 @@ newtype Fin s instance Show (Fin n) where show (UnsafeFin x) = show x +-- | Construct a Fin +-- Errors on out-of-bounds +-- +-- >>> fin @2 1 +-- 1 +-- +-- >>> fin @2 2 +-- *** Exception: value outside bounds +-- ... +fin :: forall n. (KnownNat n) => Int -> Fin n +fin x = fromMaybe (error "value outside bounds") (safeFin x) + -- | Construct a Fin safely. -- -- >>> safeFin 1 :: Maybe (Fin 2) @@ -323,6 +343,12 @@ instance Show (Fins n) where toFins :: forall s. (KnownNats s) => [Int] -> Maybe (Fins s) toFins xs = bool Nothing (Just (UnsafeFins xs)) (isFins xs (valuesOf @s)) +-- | An SNat (a type-level Nat) that represents an index into an SNats (a type-level [Nat]). The index is a dimension of the shape. +type Dim = SNat + +-- | An SNats (a type-level [Nat]) that represents indexes into an SNats (a type-level [Nat]). The indexes are dimensions of the shape. +type Dims = SNats + -- | Number of dimensions -- -- >>> rank @Int [2,3,4] @@ -336,7 +362,7 @@ rank = length -- >>> :k! Eval (Rank [2,3,4]) -- ... -- = 3 -data Rank :: t a -> Exp Natural +data Rank :: [a] -> Exp Natural type instance Eval (Rank xs) = Eval (Length xs) @@ -382,9 +408,9 @@ rerank r xs = data Rerank :: Nat -> [Nat] -> Exp [Nat] type instance Eval (Rerank r xs) = - If (Eval ((Fcf.>) r (Eval (Rank xs)))) - (Eval (Eval (Replicate (Eval ((Fcf.-) r (Eval (Rank xs)))) 1) Fcf.++ xs)) - (Eval ((Fcf.++) ('[Eval (Size (Eval (Take ((Eval (Rank xs)) - r + 1) xs)))]) + If (r >? (Eval (Rank xs))) + (Eval (Eval (Replicate (r - (Eval (Rank xs))) 1) ++ xs)) + (Eval (('[Eval (Size (Eval (Take ((Eval (Rank xs)) - r + 1) xs)))]) ++ (Eval (Drop (Eval (Rank xs) + 1 - r) xs)))) -- | Enumerate the dimensions of a shape. @@ -620,7 +646,7 @@ type instance Eval (Minimum (x ': xs)) = -- = 0 data Min :: a -> a -> Exp a -type instance Eval (Min a b) = If (Eval (a Fcf.< b)) a b +type instance Eval (Min a b) = If (a a -> Exp a -type instance Eval (Max a b) = If (Eval (a Fcf.> b)) a b +type instance Eval (Max a b) = If (a >? b) a b -- | Check if i is a valid Fin (aka in-bounds index of a dimension) -- @@ -651,7 +677,7 @@ isFin i d = (zero <= i && i + one <= d) data IsFin :: Nat -> Nat -> Exp Bool type instance Eval (IsFin x d) = - Eval ((Fcf.<) x d) + x Nat -> Exp (Maybe (a, Nat)) type instance Eval (EnumFromToHelper b a) = - If (Eval (a Fcf.> b)) + If (a >? b) 'Nothing ('Just '(a, a+1)) @@ -1010,6 +1036,7 @@ type instance Eval (SliceOk d off l s) = Eval (And [ Eval (IsFin off =<< GetDim d s), Eval ((Fcf.<) l =<< GetDim d s), + Eval ((Fcf.<) (off + l) (Eval (GetDim d s) + 1)), Eval (IsDim d s) ]) @@ -1113,6 +1140,19 @@ data GetDims :: [Nat] -> [Nat] -> Exp [Nat] type instance Eval (GetDims xs ds) = Eval (Map (Flip GetDim ds) xs) +-- | Get the index of the last position in the selected dimensions of a shape. Errors on a zero-dimension. +-- +-- >>> :k! Eval (GetLastPositions [2,0] [2,3,4]) +-- ... +-- = [3, 1] +-- >>> :k! Eval (GetLastPositions '[0] '[0]) +-- ... +-- = '[0 GHC.TypeNats.- 1] +data GetLastPositions :: [Nat] -> [Nat] -> Exp [Nat] + +type instance Eval (GetLastPositions ds s) = + Eval (Map (Flip (Fcf.-) 1) (Eval (GetDims ds s))) + -- | modify dimensions of a shape with (separate) functions. -- -- >>> modifyDims [0,1] [(+1), (+5)] [2,3,4]