From 43739b0c27ff20d7baf99c3514f95ab64a1611a8 Mon Sep 17 00:00:00 2001 From: Tony Day Date: Mon, 16 Sep 2024 17:03:12 +1000 Subject: [PATCH] AsSingleton pre-removal --- readme.org | 102 +++++++- src/NumHask/Array/Dynamic.hs | 14 +- src/NumHask/Array/Fixed.hs | 227 ++++++---------- src/NumHask/Array/Shape.hs | 495 +++++++++++++++++++---------------- 4 files changed, 459 insertions(+), 379 deletions(-) diff --git a/readme.org b/readme.org index 30c84cb..6fcc72e 100644 --- a/readme.org +++ b/readme.org @@ -578,7 +578,7 @@ Failed, three modules loaded. #+begin_src haskell-ng :results output a = fmap (1+) $ range [2,3,4] :: D.Array Int pretty a --- :t \d o l a -> backpermute (S.replaceDim d l) (S.modifyDim d (+o)) a +-- :t \d o l a -> backpermute (S.setDim d l) (S.modifyDim d (+o)) a #+end_src #+RESULTS: @@ -1108,9 +1108,9 @@ D.drops [1,0] m S.shapen [] 20 S.flatten [] [] S.deleteDim [] 2 -S.replaceDim 0 1 [] +S.setDim 0 1 [] S.modifyDim 0 (+1) [] -S.replaceDim 1 3 [] +S.setDim 1 3 [] S.reverseIndex [0] [] [] S.rotateIndex [(0,1)] [] [1] #+end_src @@ -1965,7 +1965,101 @@ ghci| ghci| ghci| ghci| ghci| ghci| ghci| ghci| ghci| ghci| ghci| ghci| https://discourse.haskell.org/t/how-to-create-arbitrary-instance-for-dependent-types/6990/5 - ** singletons https://github.com/goldfirere/singletons/blob/master/README.md + + +** dependent type examples + +#+begin_src haskell-ng :results output +example_inserta :: (Show a, FromInteger a) => SomeArray a -> String +example_inserta (SomeArray (SNats :: SNats ns) a) = show (insert (SNat @0) 0 a (toScalar 0)) +-- Could not deduce ‘HasShape +-- (If (s Data.Type.Equality.== '[]) '[1] s)’ +-- arising from a use of ‘insert’ +#+end_src + +segfaults as SNats is somehow SNat @'[] + +#+begin_src haskell-ng :results output +example_take :: forall a s. (HasShape s, Show a) => Nat -> Nat -> Array s a -> String +example_take d t a = + withSomeSNat d + (\(SNat :: SNat d) -> + withSomeSNat t + (\(SNat :: SNat t) -> + case someTakeDim @d @t @s of + SNats -> show $ take (SNat @d) (SNat @t) a)) +#+end_src + +#+begin_src haskell-ng :results output +example_take' :: forall a s. (HasShape s, Show a) => Nat -> Nat -> Array s a -> String +example_take' d t a = + withSomeSNat d + (\(SNat :: SNat d') -> + withSomeSNat t + (\(SNat :: SNat t) -> + case someTakeDim2 @d' @t @s of + x -> show x <> " " <> show (someTakeDim2 @d' @t @s) <> " : take " <> show (SNat @d') <> " " <> show (SNat @t) <> " " <> show a)) +#+end_src + +example_take' 0 1 (range @[2,3,4]) +"SNats @[1, 3, 4] SNats @[1, 3, 4] : take SNat @0 SNat @1 [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23]" + +vector examples + +#+begin_src haskell-ng :results output +data SomeVector a where + SomeVector :: KnownNat n => Vector n a -> SomeVector a + +deriving instance (Show a) => Show (SomeVector a) + +withLength :: forall n a r. Vector n a -> (KnownNat n => r) -> r +withLength v r = case someNatVal (fromIntegral $ V.length (asVector v)) of + SomeNat (Proxy :: Proxy n') -> case unsafeCoerce Refl of + (Refl :: n :~: n') -> r + +aVector :: FromVector t a => t -> SomeVector a +aVector (Array . asVector -> v) = withLength v (SomeVector v) + +example_append :: (Show a, Num a, FromInteger a) => SomeVector a -> String +example_append (SomeVector a) = show (append (SNat @0) a (toScalar 0)) + +example_insert :: (Show a, FromInteger a) => SomeVector a -> String +example_insert (SomeVector a) = show (insert (SNat @0) 0 a (toScalar 0)) + +data SomeVector' a = forall n. SomeVector' (SNat n) (Vector n a) + +deriving instance (Show a) => Show (SomeVector' a) + +someVector' :: FromVector t a => KnownNat n => SNat n -> t -> SomeVector' a +someVector' n t = SomeVector' n (vector' n t) + +aVector' :: forall a t. FromVector t a => t -> SomeVector' a +aVector' t = withSomeSNat (fromIntegral $ V.length (asVector t)) $ \(SNat :: SNat n) -> SomeVector' SNat (vector' (SNat @n) (asVector t)) + +example_insert' :: (Show a, FromInteger a) => SomeVector' a -> String +example_insert' (SomeVector' (SNat :: SNat n) a) = show (insert (SNat @0) 0 a (toScalar 0)) + +example_append' :: (Show a, Num a, FromInteger a) => SomeVector' a -> String +example_append' (SomeVector' (SNat :: SNat n) a) = show (append (SNat @0) a (toScalar 0)) + +instance (Arbitrary a) => Arbitrary (SomeVector' a) where + arbitrary = do + n <- arbitrary + v <- V.replicateM (Prelude.fromIntegral n) arbitrary + withSomeNat n $ \sn -> pure (someVector' sn v) +#+end_src + + +#+begin_src haskell-ng :results output +someTakeDim :: forall d t s. (HasShape s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) +someTakeDim = withSomeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) unsafeCoerce + +someTakeDim2 :: forall d t s. (HasShape s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) +someTakeDim2 = UnsafeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) + +someUnsafeGetIndex :: forall d s. (HasShape s, KnownNat d) => SNat (Eval (UnsafeGetIndex d s)) +someUnsafeGetIndex = withSomeSNat (fromIntegral $ getDim (int (SNat :: SNat d)) (ints (SNats :: SNats s))) unsafeCoerce +#+end_src diff --git a/src/NumHask/Array/Dynamic.hs b/src/NumHask/Array/Dynamic.hs index 4988119..c4d9f09 100644 --- a/src/NumHask/Array/Dynamic.hs +++ b/src/NumHask/Array/Dynamic.hs @@ -723,7 +723,7 @@ drop :: Array a drop d t a = backpermute dsNew (S.modifyDim d (\x -> x + bool t 0 (t < 0))) a where - dsNew = S.replaceDim d ((S.unsafeGetIndex d (shape a)) - abs t) + dsNew = S.setDim d ((S.getDim d (shape a)) - abs t) -- | Select an index along a dimension. -- @@ -823,7 +823,7 @@ append :: Array a -> Array a -> Array a -append d a b = insert d (S.unsafeGetIndex d (shape a)) a b +append d a b = insert d (S.getDim d (shape a)) a b -- | Insert along a dimension at the beginning. -- @@ -864,7 +864,7 @@ slice :: (Int, Int) -> Array a -> Array a -slice d (o, l) a = backpermute (S.replaceDim d l) (S.modifyDim d (+ o)) a +slice d (o, l) a = backpermute (S.setDim d l) (S.modifyDim d (+ o)) a -- * multi-dimension operators @@ -883,8 +883,8 @@ takes :: Array a takes ts a = backpermute dsNew (List.zipWith (+) start) a where - dsNew = S.replaceDims ds xsAbs - start = List.zipWith (\x s -> bool 0 (s + x) (x<0)) (S.replaceDimsT ts (replicate (rank a) 0)) (shape a) + dsNew = S.setDims ds xsAbs + start = List.zipWith (\x s -> bool 0 (s + x) (x<0)) (S.setDimsT ts (replicate (rank a) 0)) (shape a) ds = fmap fst ts xs = fmap snd ts xsAbs = fmap abs xs @@ -900,7 +900,7 @@ drops :: drops ts a = backpermute dsNew (List.zipWith (\d' s' -> bool (d' + s') s' (d' < 0)) xsNew) a where dsNew = S.modifyDims ds (fmap (flip (-)) xsAbs) - xsNew = S.replaceDims ds xs (replicate (rank a) 0) + xsNew = S.setDims ds xs (replicate (rank a) 0) ds = fmap fst ts xs = fmap snd ts xsAbs = fmap abs xs @@ -1620,7 +1620,7 @@ concats :: concats ds n a = backpermute concatDims unconcatDims a where concatDims s = S.insertDim n (S.size $ S.takeDims ds s) (S.deleteDims ds s) - unconcatDims s = S.insertDims (List.zip ds (S.shapen (S.takeDims ds (shape a)) (S.unsafeGetIndex n s))) (S.deleteDim n s) + unconcatDims s = S.insertDims (List.zip ds (S.shapen (S.takeDims ds (shape a)) (S.getDim n s))) (S.deleteDim n s) -- | Rotate an array along a dimension. -- diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index 6fcd1f6..ab0ea94 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -79,9 +79,6 @@ module NumHask.Array.Fixed -- ** Single-dimension operators take, - example_take, - example_take_1, - example_take_axiom, takeB, drop, dropB, @@ -185,16 +182,6 @@ module NumHask.Array.Fixed Vector, vector, vector', - SomeVector (..), - withLength, - aVector, - example_insert, - example_append, - SomeVector' (..), - someVector', - aVector', - example_insert', - example_append', iota, Matrix, ) @@ -220,12 +207,8 @@ import System.Random hiding (uniform) import System.Random.Stateful hiding (uniform) import Unsafe.Coerce import NumHask.Array.Sort --- import Data.Reflection --- import Unsafe.Coerce -import Type.Reflection import Test.QuickCheck hiding (tabulate, vector) import Test.QuickCheck.Instances.Natural () -import Data.Constraint ((\\)) -- $setup -- @@ -484,21 +467,34 @@ unsafeModifyShape a = unsafeArray (asVector a) unsafeModifyVector :: (HasShape s) => (FromVector u a) => (FromVector v b) => (u -> v) -> Array s a -> Array s b unsafeModifyVector f a = unsafeArray (asVector (f (vectorAs (asVector a)))) +-- | A fixed Array with a hidden shape. +-- +-- The library design encourages the use of dynamic arrays in preference to dependent-type styles such as this. In particular, no attempt has been made to prove to the compiler that a particular Shape (resulting from any of the supplied functions) exists. Life is short. data SomeArray a = forall s. SomeArray (SNats s) (Array s a) --- TODO: other derivings deriving instance (Show a) => Show (SomeArray a) +instance Functor SomeArray + where + fmap f (SomeArray sn a) = SomeArray sn (fmap f a) + +instance Foldable SomeArray + where + foldMap f (SomeArray _ a) = foldMap f a + + someArray :: forall s t a. FromVector t a => SNats s -> t -> SomeArray a someArray s t = SomeArray s (Array (asVector t)) --- TODO: debug this --- show . fmap (\(SomeArray _ a) -> sum (asVector a)) <$> (sample' arbitrary :: IO [SomeArray Int]) +-- | +-- > P.take 4 <$> sample' arbitrary :: IO [SomeArray Int] +-- [SomeArray SNats @'[] [0],SomeArray SNats @'[0] [],SomeArray SNats @[1, 1] [1],SomeArray SNats @[5, 1, 4] [2,1,0,2,-6,0,5,6,-1,-4,0,5,-1,6,4,-6,1,0,3,-1]] instance (Arbitrary a) => Arbitrary (SomeArray a) where arbitrary = do s <- arbitrary :: Gen [Small Nat] - v <- V.replicateM (product (Prelude.fromIntegral <$> s)) arbitrary - withSomeSNats (Prelude.take 2 (getSmall <$> s)) $ \s' -> pure (someArray s' v) + let s' = Prelude.take 3 (getSmall <$> s) + v <- V.replicateM (product (Prelude.fromIntegral <$> s')) arbitrary + withSomeSNats s' $ \sn -> pure (someArray sn v) -- | convert to a dynamic array with shape at the value level. -- @@ -806,7 +802,7 @@ colWise f _ a = f (Proxy :: Proxy ts) a -- [16], -- [20]]] take :: - forall s s' a d t. + forall d t s s' a. ( HasShape s, HasShape s', s' ~ Eval (TakeDim d t s) @@ -817,35 +813,6 @@ take :: Array s' a take _ _ a = unsafeBackpermute id a -example_take :: forall a s. (HasShape s, Show a) => Nat -> Nat -> Array s a -> String -example_take d t a = - withSomeSNat d - (\(SNat :: SNat d) -> - withSomeSNat t - (\(SNat :: SNat t) -> - case someTakeDim @d @t @s of - SNats -> show $ take (SNat @d) (SNat @t) a - _ -> "not matched")) - -example_take_1 :: forall a s. (HasShape s, Show a) => Nat -> Nat -> Array s a -> String -example_take_1 d t a = - withSomeSNat d - (\(SNat :: SNat d) -> - withSomeSNat t - (\(SNat :: SNat t) -> - case someTakeDim @d @t @s of - SNats -> show (someTakeDim @d @t @s) <> " : take " <> show (SNat @d) <> " " <> show (SNat @t) <> " " <> show a - _ -> "not matched")) - -example_take_axiom :: forall a s. (HasShape s, Show a) => Nat -> Nat -> Array s a -> String -example_take_axiom d t a = - withSomeSNat d - (\(SNat :: SNat d) -> - withSomeSNat t - (\(SNat :: SNat t) -> - (show $ (take (SNat @d) (SNat @t) a) \\ axiomTakeDim @d @t @s) - )) - -- | Take the bottom-most elements across the specified dimension. -- -- >>> pretty $ takeB (SNat @2) (SNat @1) a @@ -859,14 +826,13 @@ takeB :: forall s s' a d t. ( HasShape s, HasShape s', - Eval (IsFin d (Eval (Rank s))) ~ True, - Eval (SetIndex d (Eval (Min t (Eval (UnsafeGetIndex d s)))) s) ~ s' + s' ~ Eval (TakeDim d t s) ) => SNat d -> SNat t -> Array s a -> Array s' a -takeB d t a = unsafeBackpermute (\s -> modifyDim (int d) (\x -> x + (unsafeGetIndex (int d) (shape a)) - (int t)) s) a +takeB d t a = unsafeBackpermute (\s -> modifyDim (int d) (\x -> x + (getDim (int d) (shape a)) - (int t)) s) a -- | Drop the top-most elements across the specified dimension. -- @@ -881,8 +847,7 @@ drop :: forall s s' a d t. ( HasShape s, HasShape s', - Eval (IsFin d (Eval (Rank s))) ~ True, - Eval (SetIndex d (Eval ((Fcf.-) (Eval (UnsafeGetIndex d s)) t)) s) ~ s' + Eval (DropDim d t s) ~ s' ) => SNat d -> SNat t -> @@ -903,8 +868,7 @@ dropB :: forall s s' a d t. ( HasShape s, HasShape s', - Eval (IsFin d (Eval (Rank s))) ~ True, - Eval (SetIndex d (Eval ((Fcf.-) (Eval (UnsafeGetIndex d s)) t)) s) ~ s' + Eval (DropDim d t s) ~ s' ) => SNat d -> SNat t -> @@ -939,11 +903,10 @@ select d x a = unsafeBackpermute (S.insertDim (int d) (int x)) a -- UnsafeArray [4] [0,1,2,3] concatenate :: forall a s0 s1 d s. - ( Eval (Concatenate d (Eval (AsSingleton s0)) (Eval (AsSingleton s1))) ~ s, - HasShape s0, + ( HasShape s0, HasShape s1, HasShape s, - HasShape (Eval (AsSingleton s0)) + Eval (Concatenate d s0 s1) ~ s ) => SNat d -> Array s0 a -> @@ -958,12 +921,12 @@ concatenate d a0 a1 = tabulate (go . fromFins) a1 ( UnsafeFins $ insertDim d' - ((s !! d') - (ds0 !! d')) + (getDim d' s - getDim d' ds0) (deleteDim d' s) ) ) - ((s !! d') >= (ds0 !! d')) - ds0 = shape (asSingleton a0) + (getDim d' s >= getDim d' ds0) + ds0 = shape a0 d' = int d -- | Insert along a dimension at a position. @@ -982,10 +945,10 @@ insert :: (HasShape s, HasShape si, HasShape s', + si ~ Eval (DeleteDim d s), HasShape (Eval (AsSingleton s)), HasShape (Eval (AsSingleton si)), - -- FIXME si relationship - s' ~ Eval (IncAt d (Eval (AsSingleton s))) + s' ~ Eval (IncAt d s) ) => SNat d -> Int -> @@ -1001,6 +964,7 @@ insert sd i a b = tabulate go where xs' = fromFins xs d = Prelude.fromIntegral (fromSNat sd) + -- | Delete along a dimension at a position. -- -- FIXME: What does this do??? @@ -1038,9 +1002,10 @@ append :: forall a d pos s si s'. ( HasShape (Eval (AsSingleton s)), HasShape (Eval (AsSingleton si)), - s' ~ Eval (IncAt d (Eval (AsSingleton s))), + si ~ Eval (DeleteDim d s), + s' ~ Eval (IncAt d s), KnownNat pos, - pos ~ Eval (UnsafeGetIndex d s), + pos ~ Eval (GetDim d s), HasShape s, HasShape si, HasShape s' @@ -1064,8 +1029,9 @@ prepend :: forall a d pos s si s'. ( HasShape (Eval (AsSingleton s)), HasShape (Eval (AsSingleton si)), - s' ~ Eval (IncAt d (Eval (AsSingleton s))), - pos ~ Eval ((Fcf.-) (Eval (UnsafeGetIndex d s)) 1), + si ~ Eval (DeleteDim d s), + s' ~ Eval (IncAt d s), + pos ~ Eval ((Fcf.-) (Eval (GetDim d s)) 1), HasShape s, HasShape si, HasShape s' @@ -1106,7 +1072,7 @@ slice :: forall a d off l s s'. (HasShape s, HasShape s', - Eval (SetIndex d l s) ~ s') => + Eval (SetDim d l s) ~ s') => SNat d -> SNat off -> SNat l -> @@ -1123,7 +1089,7 @@ takes :: forall ts s' s a. ( HasShape s, HasShape s', - s' ~ Eval (ReplaceDimsT ts s) + s' ~ Eval (SetDimsT ts s) ) => Proxy ts -> Array s a -> @@ -1139,7 +1105,7 @@ takeBs :: forall ts s' s a ds xs. ( HasShape s, HasShape s', - s' ~ Eval (ReplaceDimsT ts s), + s' ~ Eval (SetDimsT ts s), ds ~ Eval (Map Fst ts), xs ~ Eval (Map Snd ts), HasShape ds, @@ -1150,7 +1116,7 @@ takeBs :: Array s' a takeBs _ a = unsafeBackpermute (List.zipWith (+) start) a where - start = List.zipWith (-) (shape a) (S.replaceDimsT (zip (shapeOf @ds) (shapeOf @xs)) (shape a)) + start = List.zipWith (-) (shape a) (S.setDimsT (zip (shapeOf @ds) (shapeOf @xs)) (shape a)) -- | Drops the top-most elements across dimension,n tuples. -- @@ -1304,7 +1270,7 @@ tails :: ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 1), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds s))), - s' ~ Eval (ReplaceDims ds ls s), + s' ~ Eval (SetDims ds ls s), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))) @@ -1330,7 +1296,7 @@ inits :: ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 0), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds s))), - s' ~ Eval (ReplaceDims ds ls s), + s' ~ Eval (SetDims ds ls s), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))) @@ -1359,13 +1325,13 @@ slices :: ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))), - Eval (ReplaceDims ds ls s) ~ s') => + Eval (SetDims ds ls s) ~ s') => Proxy ts -> Array s a -> Array s' a slices _ a = unsafeBackpermute (List.zipWith (+) o) a where - o = S.replaceDims (shapeOf @ds) (shapeOf @os) (replicate (rank a) 0) + o = S.setDims (shapeOf @ds) (shapeOf @os) (replicate (rank a) 0) -- | Extracts dimensions to an outer layer. -- @@ -1444,7 +1410,7 @@ joins :: HasShape st, HasShape si, HasShape so, - Eval (InsertDims (Eval (Zip ds so)) si) ~ st) => + Eval (InsertDims ds so si) ~ st) => Proxy ds -> Array so (Array si a) -> Array st a @@ -1476,8 +1442,8 @@ traverses :: (Applicative f, HasShape s, HasShape s', - s' ~ Eval (InsertDims (Eval (Zip ds (Eval (TakeDims ds s)))) (Eval (DeleteDims ds s))), - HasShape (Eval (InsertDims (Eval (Zip ds (Eval (TakeDims ds s)))) (Eval (DeleteDims ds s)))), + s' ~ Eval (InsertDims ds (Eval (TakeDims ds s)) (Eval (DeleteDims ds s))), + HasShape s', HasShape (Eval (DeleteDims ds s)), HasShape (Eval (TakeDims ds s)), HasShape ds) => @@ -1501,8 +1467,8 @@ maps :: HasShape so, si ~ Eval (DeleteDims ds st), so ~ Eval (TakeDims ds st), - st' ~ Eval (InsertDims (Eval (Zip ds so)) si'), - st ~ Eval (InsertDims (Eval (Zip ds so)) si) + st' ~ Eval (InsertDims ds so si'), + st ~ Eval (InsertDims ds so si) ) => (Array si a -> Array si' b) -> Proxy ds -> @@ -1547,8 +1513,8 @@ zips :: HasShape so, si ~ Eval (DeleteDims ds st), so ~ Eval (TakeDims ds st), - st' ~ Eval (InsertDims (Eval (Zip ds so)) si'), - st ~ Eval (InsertDims (Eval (Zip ds so)) si) + st' ~ Eval (InsertDims ds so si'), + st ~ Eval (InsertDims ds so si) ) => Proxy ds -> (Array si a -> Array si b -> Array si' c) -> @@ -1577,7 +1543,7 @@ modifies :: ps ~ Eval (Map Snd ts), si ~ Eval (DeleteDims ds s), so ~ Eval (TakeDims ds s), - s ~ Eval (InsertDims (Eval (Zip ds so)) si)) => + s ~ Eval (InsertDims ds so si)) => (Array si a -> Array si a) -> Proxy ts -> Array s a -> @@ -1603,8 +1569,8 @@ diffs :: HasShape postDrop, si ~ Eval (DeleteDims ds postDrop), so ~ Eval (TakeDims ds postDrop), - st' ~ Eval (InsertDims (Eval (Zip ds so)) si'), - postDrop ~ Eval (InsertDims (Eval (Zip ds so)) si), + st' ~ Eval (InsertDims ds so si'), + postDrop ~ Eval (InsertDims ds so si), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd ts), postDrop ~ Eval (DropDims ts st) @@ -2209,7 +2175,7 @@ concats :: Array s' a concats _ newd a = unsafeBackpermute unconcatDims a where - unconcatDims s = S.insertDims (List.zip ds (S.shapen (S.takeDims ds (shape a)) (S.unsafeGetIndex n s))) (S.deleteDim n s) + unconcatDims s = S.insertDims (List.zip ds (S.shapen (S.takeDims ds (shape a)) (S.getDim n s))) (S.deleteDim n s) n = int newd ds = shapeOf @ds @@ -2283,7 +2249,7 @@ sorts :: HasShape so, si ~ Eval (DeleteDims ds s), so ~ Eval (TakeDims ds s), - s ~ Eval (InsertDims (Eval (Zip ds so)) si) + s ~ Eval (InsertDims ds so si) ) => Proxy ds -> Array s a -> Array s a sorts ds a = joins ds $ unsafeModifyVector sortV (extracts ds a) @@ -2302,7 +2268,7 @@ sortsBy :: HasShape so, si ~ Eval (DeleteDims ds s), so ~ Eval (TakeDims ds s), - s ~ Eval (InsertDims (Eval (Zip ds so)) si) + s ~ Eval (InsertDims ds so si) ) => Proxy ds -> (Array si a -> Array si b) -> Array s a -> Array s a sortsBy ds c a = joins ds $ unsafeModifyVector (sortByV c) (extracts ds a) @@ -2320,7 +2286,7 @@ orders :: HasShape so, si ~ Eval (DeleteDims ds s), so ~ Eval (TakeDims ds s), - s ~ Eval (InsertDims (Eval (Zip ds so)) si) + s ~ Eval (InsertDims ds so si) ) => Proxy ds -> Array s a -> Array so Int orders ds a = unsafeModifyVector orderV (extracts ds a) @@ -2339,7 +2305,7 @@ ordersBy :: HasShape so, si ~ Eval (DeleteDims ds s), so ~ Eval (TakeDims ds s), - s ~ Eval (InsertDims (Eval (Zip ds so)) si) + s ~ Eval (InsertDims ds so si) ) => Proxy ds -> (Array si a -> Array si b) -> Array s a -> Array so Int ordersBy ds c a = unsafeModifyVector (orderByV c) (extracts ds a) @@ -2394,8 +2360,8 @@ transmit :: ds ~ Eval (EnumFromTo (Eval (Rank sa)) (Eval ((Fcf.-) (Eval (Rank sb)) 1))), sib ~ Eval (DeleteDims ds sb), sob ~ Eval (TakeDims ds sb), - sb ~ Eval (InsertDims (Eval (Zip ds sob)) sib), - sc ~ Eval (InsertDims (Eval (Zip ds sob)) sic), + sb ~ Eval (InsertDims ds sob sib), + sc ~ Eval (InsertDims ds sob sic), True ~ (Eval (IsPrefixOf sa sb))) => (Array sa a -> Array sib b -> Array sic c) -> Array sa a -> Array sb b -> Array sc c transmit f a b = maps (f a) (Proxy :: Proxy ds) b @@ -2403,47 +2369,6 @@ transmit f a b = maps (f a) (Proxy :: Proxy ds) b -- | type Vector s a = Array '[s] a -data SomeVector a where - SomeVector :: KnownNat n => Vector n a -> SomeVector a - -deriving instance (Show a) => Show (SomeVector a) - -withLength :: forall n a r. Vector n a -> (KnownNat n => r) -> r -withLength v r = case someNatVal (fromIntegral $ V.length (asVector v)) of - SomeNat (Proxy :: Proxy n') -> case unsafeCoerce Refl of - (Refl :: n :~: n') -> r - -aVector :: FromVector t a => t -> SomeVector a -aVector (Array . asVector -> v) = withLength v (SomeVector v) - -example_append :: (Show a, Num a, FromInteger a) => SomeVector a -> String -example_append (SomeVector a) = show (append (SNat @0) a (toScalar 0)) - -example_insert :: (Show a, FromInteger a) => SomeVector a -> String -example_insert (SomeVector a) = show (insert (SNat @0) 0 a (toScalar 0)) - -data SomeVector' a = forall n. SomeVector' (SNat n) (Vector n a) - -deriving instance (Show a) => Show (SomeVector' a) - -someVector' :: FromVector t a => KnownNat n => SNat n -> t -> SomeVector' a -someVector' n t = SomeVector' n (vector' n t) - -aVector' :: forall a t. FromVector t a => t -> SomeVector' a -aVector' t = withSomeSNat (fromIntegral $ V.length (asVector t)) $ \(SNat :: SNat n) -> SomeVector' SNat (vector' (SNat @n) (asVector t)) - -example_insert' :: (Show a, FromInteger a) => SomeVector' a -> String -example_insert' (SomeVector' (SNat :: SNat n) a) = show (insert (SNat @0) 0 a (toScalar 0)) - -example_append' :: (Show a, Num a, FromInteger a) => SomeVector' a -> String -example_append' (SomeVector' (SNat :: SNat n) a) = show (append (SNat @0) a (toScalar 0)) - -instance (Arbitrary a) => Arbitrary (SomeVector' a) where - arbitrary = do - n <- arbitrary - v <- V.replicateM (Prelude.fromIntegral n) arbitrary - withSomeNat n $ \sn -> pure (someVector' sn v) - -- | A one-dimensional array. -- -- >>> pretty $ vector @3 @Int [2,3,4] @@ -2522,8 +2447,9 @@ cons :: KnownNat pos, HasShape (Eval (AsSingleton st)), HasShape (Eval (AsSingleton sh)), - s ~ Eval (IncAt 0 (Eval (AsSingleton st))), - pos ~ Eval ((Fcf.-) (Eval (UnsafeGetIndex 0 st)) 1)) => + s ~ Eval (IncAt 0 st), + sh ~ Eval (DeleteDim 0 st), + pos ~ Eval ((Fcf.-) (Eval (GetDim 0 st)) 1)) => Array sh a -> Array st a -> Array s a cons = prepend (SNat @0) @@ -2540,9 +2466,10 @@ snoc :: forall si s sl a pos. HasShape sl, HasShape (Eval (AsSingleton si)), HasShape (Eval (AsSingleton sl)), - s ~ Eval (IncAt 0 (Eval (AsSingleton si))), + s ~ Eval (IncAt 0 si), + sl ~ Eval (DeleteDim 0 si), KnownNat pos, - pos ~ Eval (UnsafeGetIndex 0 si)) => + pos ~ Eval (GetDim 0 si)) => Array si a -> Array sl a -> Array s a snoc = append (SNat @0) @@ -2565,7 +2492,7 @@ uncons :: ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 1), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds (Eval (AsSingleton s))))), - st ~ Eval (ReplaceDims ds ls (Eval (AsSingleton s))), + st ~ Eval (SetDims ds ls (Eval (AsSingleton s))), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))) @@ -2593,7 +2520,7 @@ unsnoc :: ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 0), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds (Eval (AsSingleton s))))), - si ~ Eval (ReplaceDims ds ls (Eval (AsSingleton s))), + si ~ Eval (SetDims ds ls (Eval (AsSingleton s))), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))), @@ -2620,18 +2547,19 @@ pattern (:<) :: KnownNat pos, HasShape (Eval (AsSingleton st)), HasShape (Eval (AsSingleton sh)), - s ~ Eval (IncAt 0 (Eval (AsSingleton st))), - pos ~ Eval ((Fcf.-) (Eval (UnsafeGetIndex 0 st)) 1), + s ~ Eval (IncAt 0 st), + pos ~ Eval ((Fcf.-) (Eval (GetDim 0 st)) 1), ds ~ '[0], HasShape (Eval (AsSingleton s)), sh ~ Eval (DeleteDims ds (Eval (AsSingleton s))), sh ~ Eval (Drop 1 s), + sh ~ Eval (DeleteDim 0 st), HasShape ls, HasShape os, ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 1), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds (Eval (AsSingleton s))))), - st ~ Eval (ReplaceDims ds ls (Eval (AsSingleton s))), + st ~ Eval (SetDims ds ls (Eval (AsSingleton s))), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts)))) => @@ -2660,18 +2588,19 @@ pattern (:>) :: HasShape s, HasShape (Eval (AsSingleton si)), HasShape (Eval (AsSingleton sl)), - s ~ Eval (IncAt 0 (Eval (AsSingleton si))), + s ~ Eval (IncAt 0 si), KnownNat pos, - pos ~ Eval (UnsafeGetIndex 0 si), + pos ~ Eval (GetDim 0 si), HasShape ds, HasShape ls, HasShape os, HasShape (Eval (AsSingleton s)), + sl ~ Eval (DeleteDim 0 si), ds ~ '[0], ts ~ Eval (Zip ds (Eval (Zip os ls))), os ~ Eval (Replicate (Eval (Rank ds)) 0), ls ~ Eval (Map (Flip (Fcf.-) 1) (Eval (TakeDims ds (Eval (AsSingleton s))))), - si ~ Eval (ReplaceDims ds ls (Eval (AsSingleton s))), + si ~ Eval (SetDims ds ls (Eval (AsSingleton s))), ds ~ Eval (Map Fst ts), ls ~ Eval (Map Snd (Eval (Map Snd ts))), os ~ Eval (Map Fst (Eval (Map Snd ts))), diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index d87c004..5977784 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -40,6 +40,8 @@ module NumHask.Array.Shape ints, KnownNats (..), natVals, + SomeNats, + someNatVals, withSomeSNats, HasShape, shapeOf, @@ -59,10 +61,6 @@ module NumHask.Array.Shape AsSingleton, asScalar, AsScalar, - GetIndex, - unsafeGetIndex, - UnsafeGetIndex, - SetIndex, rotate, type (++), Take, @@ -79,11 +77,6 @@ module NumHask.Array.Shape Min, minimum, Minimum, - modifyDim, - ModifyDim, - replaceDim, - ReplaceDim, - ReplaceDimUncurried, incAt, IncAt, decAt, @@ -91,19 +84,16 @@ module NumHask.Array.Shape insertDim, InsertDim, InsertDimUncurried, - deleteDim, - DeleteDim, preDeletePositions, PreDeletePositions, preInsertPositions, PreInsertPositions, insertDims, InsertDims, - InsertDimsHelper, - replaceDims, - ReplaceDims, - replaceDimsT, - ReplaceDimsT, + setDims, + SetDims, + setDimsT, + SetDimsT, modifyDims, deleteDims, DeleteDims, @@ -140,12 +130,21 @@ module NumHask.Array.Shape EnumFromTo, Foldl', - -- * Specific Constraints + -- * Shape dimension manipulation + GetIndex, + SetIndex, + getDim, + GetDim, + modifyDim, + ModifyDim, + setDim, + SetDim, takeDim, TakeDim, - someTakeDim, - someTakeDim', - axiomTakeDim, + dropDim, + DropDim, + deleteDim, + DeleteDim, ) where @@ -155,7 +154,7 @@ import Data.Type.Bool hiding (Not) import Data.Type.Equality import GHC.TypeLits qualified as L import Prelude qualified -import NumHask.Prelude as P hiding (Min, Last, minimum) +import NumHask.Prelude as P hiding (Min, Max, Last, minimum) import Data.Coerce import Data.Data import GHC.Arr @@ -163,7 +162,7 @@ import GHC.Exts import GHC.TypeNats import GHC.TypeLits (TypeError, ErrorMessage(..)) import Text.Read -import Data.Type.Ord hiding (Min) +import Data.Type.Ord hiding (Min, Max) import Unsafe.Coerce import Fcf hiding (type (&&), type (+), type (-), type (++)) import Fcf qualified @@ -171,7 +170,6 @@ import Fcf.Class.Foldable import Fcf.Data.List import Fcf.Combinators import Control.Monad -import Data.Constraint (Dict (Dict), (:-) (Sub), (\\)) -- $setup -- >>> :m -Prelude @@ -221,6 +219,7 @@ type role SNats nominal pattern SNats :: forall ns. () => KnownNats ns => SNats ns pattern SNats <- (knownNatsInstance -> KnownNatsInstance) where SNats = natsSing +{-# COMPLETE SNats #-} fromSNats :: SNats s -> [Nat] fromSNats (UnsafeSNats s) = s @@ -265,6 +264,11 @@ withSomeSNats :: forall rep (r :: TYPE rep). [Nat] -> (forall s. SNats s -> r) -> r withSomeSNats s k = k (UnsafeSNats s) +data SomeNats = forall s. KnownNats s => SomeNats (Proxy s) + +someNatVals :: [Nat] -> SomeNats +someNatVals s = withSomeSNats s (\(sn :: SNats s) -> + withKnownNats sn (SomeNats @s Proxy)) {- -- | The Shape type holds a [Nat] at type level and the equivalent [Int] at value level. @@ -430,7 +434,7 @@ size xs = P.product xs -- >>> :k! (Eval (Size [2,3,4])) -- (Eval (Size [2,3,4])) :: Natural -- = 24 -data Size :: t Nat -> Exp Nat +data Size :: [Nat] -> Exp Nat type instance Eval (Size xs) = Eval (Foldr (Fcf.*) 1 xs) @@ -495,7 +499,7 @@ inside xs ds = (rank xs == rank ds) && (List.and $ List.zipWith (\x d -> x >= ze -- >>> :k! Eval (Inside [2,1] '[1]) -- Eval (Inside [2,1] '[1]) :: Bool -- = False -data Inside :: t Nat -> t Nat -> Exp Bool +data Inside :: [Nat] -> [Nat] -> Exp Bool type instance Eval (Inside xs ds) = Eval (LiftM2 (Fcf.&&) @@ -513,7 +517,7 @@ type instance Eval (Inside xs ds) = -- >>> :k! Eval (ShapeLTE [2,1] '[1]) -- Eval (ShapeLTE [2,1] '[1]) :: Bool -- = False -data ShapeLTE :: t Nat -> t Nat -> Exp Bool +data ShapeLTE :: [Nat] -> [Nat] -> Exp Bool type instance Eval (ShapeLTE xs ys) = Eval (LiftM2 (Fcf.&&) @@ -537,7 +541,7 @@ asSingleton x = x -- >>> :k! Eval (AsSingleton [2,3,4]) -- ... -- = [2, 3, 4] -data AsSingleton :: t Nat -> Exp (t Nat) +data AsSingleton :: [Nat] -> Exp [Nat] type instance Eval (AsSingleton xs) = If (xs == '[]) '[1] xs @@ -559,7 +563,7 @@ asScalar x = x -- >>> :k! Eval (AsScalar [2,3,4]) -- ... -- = [2, 3, 4] -data AsScalar :: t Nat -> Exp (t Nat) +data AsScalar :: [Nat] -> Exp [Nat] type instance Eval (AsScalar xs) = If (xs == '[1]) '[] xs @@ -575,42 +579,6 @@ rotate r xs = drop r' xs <> take r' xs where r' = r `mod` List.length xs --- | Get an element at a given index. --- --- >>> :kind! Eval (GetIndex 2 [2,3,4]) --- ... --- = Just 4 -data GetIndex :: Nat -> [a] -> Exp (Maybe a) -type instance Eval (GetIndex n xs) = GetIndexImpl n xs - -type family GetIndexImpl (n :: Nat) (xs :: [k]) where - GetIndexImpl _ '[] = 'Nothing - GetIndexImpl 0 (x ': _) = 'Just x - GetIndexImpl n (_ ': xs) = GetIndexImpl (n - 1) xs - --- | unsafeGetIndex i xs is the i'th element of xs (or error if out-of-bounds) --- --- >>> unsafeGetIndex 1 [2,3,4] --- 3 --- >>> unsafeGetIndex 3 [2,3,4] --- *** Exception: unsafeGetIndex outside bounds --- ... -unsafeGetIndex :: Int -> [Int] -> Int -unsafeGetIndex 0 (s : _) = s -unsafeGetIndex n (_ : s) = unsafeGetIndex (n - 1) s -unsafeGetIndex _ _ = error "unsafeGetIndex outside bounds" - --- | UnsafeGetIndex i xs is the i'th element of xs (or error if out-of-bounds) --- --- >>> :k! Eval (UnsafeGetIndex 1 [2,3,4]) --- ... --- = 3 --- >>> :k! Eval (UnsafeGetIndex 3 [2,3,4]) --- ... --- = (TypeError ...) -data UnsafeGetIndex :: Nat -> [a] -> Exp a -type instance Eval (UnsafeGetIndex n xs) = Eval (FromMaybe (L.TypeError (L.Text "UnsafeGetIndex out of bounds: " :<>: ShowType n :<>: L.Text " " :<>: ShowType xs)) (Eval (GetIndex n xs))) - -- | minimum dimension -- -- >>> S.minimum [] @@ -641,110 +609,31 @@ data Min :: a -> a -> Exp a type instance Eval (Min a b) = If (Eval (a Fcf.< b)) a b --- | delete the i'th dimension --- --- >>> deleteDim 1 [2, 3, 4] --- [2,4] --- >>> deleteDim 2 [] --- [] -deleteDim :: Int -> [Int] -> [Int] -deleteDim i s = take i s ++ drop (i + 1) s - --- | delete the i'th dimension --- --- >>> :k! Eval (DeleteDim 1 [2, 3, 4]) --- ... --- = [2, 4] --- >>> :k! Eval (DeleteDim 1 '[]) --- ... --- = '[] -data DeleteDim :: Nat -> [Nat] -> Exp [Nat] +data Max :: a -> a -> Exp a -type instance Eval (DeleteDim i ds) = - Eval (LiftM2 (Fcf.++) (Take i ds) (Drop (i + 1) ds)) +type instance Eval (Max a b) = If (Eval (a Fcf.> b)) a b --- | /insertDim d i s/ inserts a new dimension of value i to shape /s/ at position /d/ --- --- >>> insertDim 1 3 [2,4] --- [2,3,4] --- >>> insertDim 0 4 [] --- [4] -insertDim :: Int -> Int -> [Int] -> [Int] -insertDim d i s = take d s ++ (i : drop d s) - --- | /insertDim d i s/ inserts a new dimension of value i to shape /s/ at position /d/ --- --- >>> :k! Eval (InsertDim 1 3 [2,4]) --- ... --- = [2, 3, 4] --- >>> :k! Eval (InsertDim 0 4 '[]) --- ... --- = '[4] -data InsertDim :: Nat -> Nat -> [Nat] -> Exp [Nat] - -type instance Eval (InsertDim d i ds) = - Eval (Eval (Take d ds) Fcf.++ (i ': Eval (Drop d ds))) - -data InsertDimUncurried :: (Nat,Nat) -> [Nat] -> Exp [Nat] - -type instance Eval (InsertDimUncurried xs ds) = - Eval (Eval (Take (Eval (Fst xs)) ds) Fcf.++ ((Eval (Snd xs)) ': Eval (Drop (Eval (Fst xs)) ds))) - --- | modify an index at a specific dimension. Unmodified if out of bounds. --- --- >>> modifyDim 0 (+1) [0,1,2] --- [1,1,2] -modifyDim :: Int -> (Int -> Int) -> [Int] -> [Int] -modifyDim d f xs = take d xs <> (pure . f) (xs List.!! d) <> drop (d + 1) xs - --- | modify an index at a specific dimension. Unmodified if out of bounds. --- --- >>> :k! Eval (ModifyDim 0 ((Fcf.+) 1) [0,1,2]) --- ... --- = [1, 1, 2] -data ModifyDim :: Nat -> (Nat -> Exp Nat) -> [Nat] -> Exp [Nat] - -type instance Eval (ModifyDim d f ds) = - Eval (FromMaybe ds =<< (Map (Flip (SetIndex d) ds) =<< (Map f =<< (GetIndex d ds)))) - --- | replace an index at a specific dimension. --- --- >>> replaceDim 0 1 [2,3,4] --- [1,3,4] -replaceDim :: Int -> Int -> [Int] -> [Int] -replaceDim d x xs = modifyDim d (const x) xs - --- | replace an index at a specific dimension. --- --- >>> :k! Eval (ReplaceDim 0 1 [2,3,4]) --- ... --- = [1, 3, 4] -data ReplaceDim :: Nat -> Nat -> [Nat] -> Exp [Nat] - -type instance Eval (ReplaceDim d x ds) = - Eval (SetIndex d x ds) - -data ReplaceDimUncurried :: (Nat,Nat) -> [Nat] -> Exp [Nat] - -type instance Eval (ReplaceDimUncurried xs ds) = - Eval (SetIndex (Eval (Fst xs)) (Eval (Snd xs)) ds) - --- | Increment the index at a dimension of a shape by one. +-- | Increment the index at a dimension of a shape by one. Scalars turn into singletons. -- -- >>> incAt 1 [2,3,4] -- [2,4,4] +-- >>> incAt 0 [] +-- [2] incAt :: Int -> [Int] -> [Int] -incAt d ds = modifyDim d (+1) ds +incAt d ds = modifyDim d (+1) (asSingleton ds) --- | Increment the index at a dimension of a shape by one. +-- | Increment the index at a dimension of a shape by one. Scalars turn into singletons. -- -- >>> :k! Eval (IncAt 1 [2,3,4]) -- ... -- = [2, 4, 4] -data IncAt :: Nat -> t Nat -> Exp (t Nat) +-- >>> :k! Eval (IncAt 0 '[]) +-- ... +-- = '[2] +data IncAt :: Nat -> [Nat] -> Exp [Nat] type instance Eval (IncAt d ds) = - Eval (ModifyDim d ((Fcf.+) 1) ds) + Eval (ModifyDim d ((Fcf.+) 1) (Eval (AsSingleton ds))) -- | Decrement the index at a dimension os a shape by one. -- @@ -758,7 +647,7 @@ decAt d ds = modifyDim d (\x -> x - 1) ds -- >>> :k! Eval (DecAt 1 [2,3,4]) -- ... -- = [2, 2, 4] -data DecAt :: Nat -> t Nat -> Exp (t Nat) +data DecAt :: Nat -> [Nat] -> Exp [Nat] type instance Eval (DecAt d ds) = Eval (ModifyDim d (Flip (Fcf.-) 1) ds) @@ -791,12 +680,12 @@ preDeletePositions as = reverse (go as []) -- >>> :k! Eval (PreDeletePositions [1,2,0]) -- ... -- = [1, 1, 0] -data PreDeletePositions :: t Nat -> Exp (t Nat) +data PreDeletePositions :: [Nat] -> Exp [Nat] type instance Eval (PreDeletePositions xs) = Eval (Reverse (Eval (PreDeletePositionsGo xs '[]))) -data PreDeletePositionsGo :: t Nat -> t Nat -> Exp (t Nat) +data PreDeletePositionsGo :: [Nat] -> [Nat] -> Exp [Nat] type instance Eval (PreDeletePositionsGo '[] rs) = rs type instance Eval (PreDeletePositionsGo (x : xs) r) = @@ -832,7 +721,7 @@ preInsertPositions = reverse . preDeletePositions . reverse -- >>> :k! Eval (PreInsertPositions [1,2,0]) -- ... -- = [0, 1, 0] -data PreInsertPositions :: t Nat -> Exp (t Nat) +data PreInsertPositions :: [Nat] -> Exp [Nat] type instance Eval (PreInsertPositions xs) = Eval (Reverse =<< (PreDeletePositions =<< (Reverse xs))) @@ -849,7 +738,7 @@ deleteDims i s = foldl' (flip deleteDim) s (preDeletePositions i) -- >>> :k! Eval (DeleteDims [1,0] [2, 3, 4]) -- ... -- = '[4] -data DeleteDims :: t Nat -> t Nat -> Exp (t Nat) +data DeleteDims :: [Nat] -> [Nat] -> Exp [Nat] type instance Eval (DeleteDims xs ds) = Eval (Foldl' (Flip DeleteDim) ds (Eval (PreDeletePositions xs))) @@ -861,77 +750,74 @@ type instance Eval (DeleteDims xs ds) = -- >>> insertDims [(1,3), (0,2)] [4] -- [2,3,4] insertDims :: [(Int,Int)] -> [Int] -> [Int] -insertDims ps ds = foldl' (flip (uncurry insertDim)) ds ps' +insertDims ps s = foldl' (flip (uncurry insertDim)) s ps' where ps' = zip (preInsertPositions $ fmap fst ps) (fmap snd ps) -- | insert a list of dimensions according to dimension,position tuple lists. Note that the list of positions references the final shape and not the initial shape. -- --- >>> :k! Eval (InsertDims '[ '(0,5)] '[]) +-- >>> import Fcf +-- >>> :k! Eval (Foldl' (Flip InsertDimUncurried) '[] (Eval (Zip (Eval (PreInsertPositions '[0])) '[5]))) +-- Eval (Foldl' (Flip InsertDimUncurried) '[] (Eval (Zip (Eval (PreInsertPositions '[0])) '[5]))) :: [Natural] +-- = '[5] +-- >>> :k! Eval (InsertDims '[0] '[5] '[]) -- ... -- = '[5] --- >>> :k! Eval (InsertDims '[ '(1, 3), '(0, 2)] '[4]) +-- >>> :k! Eval (InsertDims [1,0] [3,2] '[4]) -- ... -- = [2, 3, 4] -data InsertDims :: t (Nat, Nat) -> t Nat -> Exp (t Nat) +data InsertDims :: [Nat] -> [Nat] -> [Nat] -> Exp [Nat] -type instance Eval (InsertDims ps ds) = - Eval (Foldl' (Flip InsertDimUncurried) ds (Eval (InsertDimsHelper ps))) - -data InsertDimsHelper :: t (Nat, Nat) -> Exp (t (Nat, Nat)) - -type instance Eval (InsertDimsHelper ps) = - Eval (Zip (Eval (PreInsertPositions (Eval (Map Fst ps)))) (Eval (Map Snd ps))) +type instance Eval (InsertDims ds xs s) = + Eval (Foldl' (Flip InsertDimUncurried) s (Eval (Zip (Eval (PreInsertPositions ds)) xs))) -- | replace indexes with a new value according to a dimension list. -- --- >>> replaceDims [0,1] [1,5] [2,3,4] +-- >>> setDims [0,1] [1,5] [2,3,4] -- [1,5,4] -- --- >>> replaceDims [0] [3] [] +-- >>> setDims [0] [3] [] -- [3] -replaceDims :: [Int] -> [Int] -> [Int] -> [Int] -replaceDims ds xs ns = foldl' (\ns' (d, x) -> replaceDim d x ns') ns (zip ds xs) +setDims :: [Int] -> [Int] -> [Int] -> [Int] +setDims ds xs ns = foldl' (\ns' (d, x) -> setDim d x ns') ns (zip ds xs) -- | replace indexes with a new value according to a dimension list. -- --- FIXME: Why is this different to value-level result? --- --- >>> :k! Eval (ReplaceDims [0,1] [1,5] [2,3,4]) +-- >>> :k! Eval (SetDims [0,1] [1,5] [2,3,4]) -- ... -- = [1, 5, 4] -- --- >>> :k! Eval (ReplaceDims '[0] '[3] '[]) +-- >>> :k! Eval (SetDims '[0] '[3] '[]) -- ... --- = '[] -data ReplaceDims :: t Nat -> t Nat -> t Nat -> Exp (t Nat) +-- = '[3] +data SetDims :: [Nat] -> [Nat] -> [Nat] -> Exp [Nat] -type instance Eval (ReplaceDims ds xs ns) = - Eval (Foldl' (Flip ReplaceDimUncurried) ns (Eval (Zip ds xs))) +type instance Eval (SetDims ds xs ns) = + Eval (Foldl' (Flip SetDimUncurried) ns (Eval (Zip ds xs))) -- | replace indexes dimension,value tuple list. -- --- >>> replaceDimsT [(0,1),(1,5)] [2,3,4] +-- >>> setDimsT [(0,1),(1,5)] [2,3,4] -- [1,5,4] -- --- >>> replaceDimsT [(0,3)] [] +-- >>> setDimsT [(0,3)] [] -- [3] -replaceDimsT :: [(Int, Int)] -> [Int] -> [Int] -replaceDimsT ts ns = foldl' (\ns' (d, x) -> replaceDim d x ns') ns ts +setDimsT :: [(Int, Int)] -> [Int] -> [Int] +setDimsT ts ns = foldl' (\ns' (d, x) -> setDim d x ns') ns ts -- | replace indexes with a new value according to a dimension list. -- --- >>> :k! Eval (ReplaceDims [0,1] [1,5] [2,3,4]) +-- >>> :k! Eval (SetDims [0,1] [1,5] [2,3,4]) -- ... -- = [1, 5, 4] -- --- >>> :k! Eval (ReplaceDims '[0] '[3] '[]) +-- >>> :k! Eval (SetDims '[0] '[3] '[]) -- ... --- = '[] -data ReplaceDimsT :: t (Nat,Nat) -> t Nat -> Exp (t Nat) +-- = '[3] +data SetDimsT :: t (Nat,Nat) -> [Nat] -> Exp [Nat] -type instance Eval (ReplaceDimsT ts ns) = - Eval (Foldl' (Flip ReplaceDimUncurried) ns ts) +type instance Eval (SetDimsT ts ns) = + Eval (Foldl' (Flip SetDimUncurried) ns ts) -- | modify indexes with (separate) functions according to a dimension list. -- @@ -958,17 +844,17 @@ takeDims i s = (s List.!!) <$> i -- >>> :k! Eval (TakeDims '[2] '[]) -- ... -- = '[(TypeError ...)] -data TakeDims :: t Nat -> t Nat -> Exp (t Nat) +data TakeDims :: [Nat] -> [Nat] -> Exp [Nat] type instance Eval (TakeDims xs ds) = - Eval (Map (Flip UnsafeGetIndex ds) xs) + Eval (Map (Flip GetDim ds) xs) -- | Compute new size given a drop,n tuple list -- -- >>> dropDims [(0,1),(2,3)] [2,3,4] -- [1,3,1] dropDims :: [(Int, Int)] -> [Int] -> [Int] -dropDims ts ds = replaceDimsT ts' ds +dropDims ts ds = setDimsT ts' ds where xs' = zipWith (-) (takeDims (fmap fst ts) ds) (fmap snd ts) ts' = zip (fmap fst ts) xs' @@ -981,7 +867,7 @@ dropDims ts ds = replaceDimsT ts' ds data DropDims :: [(Nat,Nat)] -> [Nat] -> Exp [Nat] type instance Eval (DropDims ts ds) = - Eval (ReplaceDimsT (Eval (Zip (Eval (Map Fst ts)) (Eval (ZipWith (Fcf.-) (Eval (TakeDims (Eval (Map Fst ts)) ds)) (Eval (Map Snd ts)))))) ds) + Eval (SetDimsT (Eval (Zip (Eval (Map Fst ts)) (Eval (ZipWith (Fcf.-) (Eval (TakeDims (Eval (Map Fst ts)) ds)) (Eval (Map Snd ts)))))) ds) -- | Turn a list of included positions for a given rank into a list of excluded positions -- @@ -995,7 +881,7 @@ exclude r xs = deleteDims xs [0 .. (r - 1)] -- > :k! Eval (Exclude 3 [1,2]) -- ... -- = '[0] -data Exclude :: Nat -> t Nat -> Exp (t Nat) +data Exclude :: Nat -> [Nat] -> Exp [Nat] type instance Eval (Exclude r xs) = Eval (DeleteDims (Eval (EnumFromTo 0 (Eval ((Fcf.-) r 1)))) xs) @@ -1017,17 +903,39 @@ concatenate :: Int -> [Int] -> [Int] -> [Int] concatenate _ [] [] = [2] concatenate _ [] [x] = [x + 1] concatenate _ [x] [] = [x + 1] -concatenate i s0 s1 = take i s0 ++ (unsafeGetIndex i s0 + unsafeGetIndex i s1 : drop (i + 1) s0) +concatenate i s0 s1 = take i s0 ++ (getDim i s0 + getDim i s1 : drop (i + 1) s0) -data Concatenate :: Nat -> t Nat -> t Nat -> Exp (t Nat) +-- | concatenate two arrays at dimension i +-- +-- Bespoke logic for scalars. +-- +-- >>> :k! Eval (Concatenate 1 [2,3,4] [2,3,4]) +-- ... +-- = [2, 6, 4] +-- >>> :k! Eval (Concatenate 0 '[3] '[]) +-- ... +-- = '[4] +-- >>> :k! Eval (Concatenate 0 '[] '[3]) +-- ... +-- = '[4] +-- >>> :k! Eval (Concatenate 0 '[] '[]) +-- ... +-- = '[2] +data Concatenate :: Nat -> [Nat] -> [Nat] -> Exp [Nat] type instance Eval (Concatenate i s0 s1) = + Eval (ConcatenateHelper i (Eval (AsSingleton s0)) (Eval (AsSingleton s1))) + +data ConcatenateHelper :: Nat -> [Nat] -> [Nat] -> Exp [Nat] + + +type instance Eval (ConcatenateHelper i s0 s1) = If (Eval (ConcatenateOk i s0 s1)) - (Eval (Eval (Take i s0) ++ (Eval (UnsafeGetIndex i s0) + Eval (UnsafeGetIndex i s1) : Eval (Drop (i + 1) s0)))) + (Eval (Eval (Take i s0) ++ (Eval (GetDim i s0) + Eval (GetDim i s1) : Eval (Drop (i + 1) s0)))) (L.TypeError (L.Text "Concatenate Mis-matched shapes.")) -- | Concatenate is Ok if ranks are the same and the non-indexed portion of the shapes are the same. -data ConcatenateOk :: Nat -> t Nat -> t Nat -> Exp Bool +data ConcatenateOk :: Nat -> [Nat] -> [Nat] -> Exp Bool type instance Eval (ConcatenateOk i s0 s1) = Eval (IsFin i (Eval (Rank s0))) @@ -1041,21 +949,21 @@ type instance Eval (ConcatenateOk i s0 s1) = reorder :: [Int] -> [Int] -> [Int] reorder [] _ = [] reorder _ [] = [] -reorder s (d : ds) = unsafeGetIndex d s : reorder s ds +reorder s (d : ds) = getDim d s : reorder s ds -- | Reorder the dimensions of shape according to a list of positions. -- -- >>> :k! Eval (Reorder [2,3,4] [2,0,1]) -- ... -- = [4, 2, 3] -data Reorder :: t Nat -> t Nat -> Exp (t Nat) +data Reorder :: [Nat] -> [Nat] -> Exp [Nat] type instance Eval (Reorder ds xs) = If ( Eval (ReorderOk ds xs)) - (Eval (Map (Flip UnsafeGetIndex ds) xs)) + (Eval (Map (Flip GetDim ds) xs)) (L.TypeError ('Text "Reorder dimension indices out of bounds")) -data ReorderOk :: t Nat -> t Nat -> Exp Bool +data ReorderOk :: [Nat] -> [Nat] -> Exp Bool type instance Eval (ReorderOk ds xs) = Eval (TyEq (Eval (Rank ds)) (Eval (Rank xs))) && @@ -1186,8 +1094,103 @@ data Foldl' :: (b -> a -> Exp b) -> b -> t a -> Exp b type instance Eval (Foldl' f y '[]) = y type instance Eval (Foldl' f y (x ': xs)) = Eval (Foldl' f (Eval (f y x)) xs) +-- | Get an element at a given index. +-- +-- >>> :kind! Eval (GetIndex 2 [2,3,4]) +-- ... +-- = Just 4 +data GetIndex :: Nat -> [a] -> Exp (Maybe a) + +type instance Eval (GetIndex d xs) = GetIndexImpl d xs + +type family GetIndexImpl (n :: Nat) (xs :: [k]) where + GetIndexImpl _ '[] = 'Nothing + GetIndexImpl 0 (x ': _) = 'Just x + GetIndexImpl n (_ ': xs) = GetIndexImpl (n - 1) xs + +-- | getDim i xs is the i'th element of xs. getDim 0 [] is 1 (to account for scalars). Error if out-of-bounds. +-- +-- >>> getDim 1 [2,3,4] +-- 3 +-- >>> getDim 3 [2,3,4] +-- *** Exception: getDim outside bounds +-- ... +-- >>> getDim 0 [] +-- 1 +getDim :: Int -> [Int] -> Int +getDim 0 [] = 1 +getDim i s = fromMaybe (error "getDim outside bounds") (s List.!? i) + +-- | GetDim i xs is the i'th element of xs. getDim 0 [] is 1 (to account for scalars). Error if out-of-bounds or non-computable (usually unknown to the compiler). +-- +-- >>> :k! Eval (GetDim 1 [2,3,4]) +-- ... +-- = 3 +-- >>> :k! Eval (GetDim 3 [2,3,4]) +-- ... +-- = (TypeError ...) +-- >>> :k! Eval (GetDim 0 '[]) +-- ... +-- = 1 +data GetDim :: Nat -> [Nat] -> Exp Nat +type instance Eval (GetDim n xs) = + If (Eval (And [(Eval (TyEq n 0)), (Eval (TyEq xs ('[]::[Nat])))])) 1 + (Eval (FromMaybe (L.TypeError (L.Text "GetDim out of bounds or non-computable: " :<>: ShowType n :<>: L.Text " " :<>: ShowType xs)) (Eval (GetIndex n xs)))) + +-- | modify an index at a specific dimension. Errors if out of bounds. +-- +-- >>> modifyDim 0 (+1) [0,1,2] +-- [1,1,2] +-- >>> modifyDim 0 (+1) [] +-- [2] +modifyDim :: Int -> (Int -> Int) -> [Int] -> [Int] +modifyDim 0 f [] = [f 1] +modifyDim d f xs = + getDim d xs & + f & + (:drop (d+1) xs) & + (take d xs <>) + +-- | modify an index at a specific dimension. Errors if out of bounds. +-- +-- >>> :k! Eval (ModifyDim 0 ((Fcf.+) 1) [0,1,2]) +-- ... +-- = [1, 1, 2] +data ModifyDim :: Nat -> (Nat -> Exp Nat) -> [Nat] -> Exp [Nat] + +type instance Eval (ModifyDim d f s) = + Eval ( LiftM2 (Fcf.++) (Take d s) (LiftM2 Cons (f =<< (GetDim d s)) (Drop (d+1) s))) + +-- | replace an index at a specific dimension, or transform a scalar into being one-dimensional. +-- +-- >>> setDim 0 1 [2,3,4] +-- [1,3,4] +-- >>> setDim 0 3 [] +-- [3] +setDim :: Int -> Int -> [Int] -> [Int] +setDim d x xs = modifyDim d (const x) xs + +-- | replace an index at a specific dimension. +-- +-- >>> :k! Eval (SetDim 0 1 [2,3,4]) +-- ... +-- = [1, 3, 4] +data SetDim :: Nat -> Nat -> [Nat] -> Exp [Nat] + +type instance Eval (SetDim d x ds) = + Eval (ModifyDim d (ConstFn x) ds) + +data SetDimUncurried :: (Nat,Nat) -> [Nat] -> Exp [Nat] + +type instance Eval (SetDimUncurried xs ds) = + Eval (SetDim (Eval (Fst xs)) (Eval (Snd xs)) ds) + +-- | Take along a dimension. +-- +-- >>> takeDim 0 1 [2,3,4] +-- [1,3,4] takeDim :: Int -> Int -> [Int] -> [Int] -takeDim d t s = replaceDim d (min t (unsafeGetIndex d s)) s +takeDim d t s = modifyDim d (min t) s -- | Take along a dimension. -- @@ -1198,20 +1201,74 @@ data TakeDim :: Nat -> Nat -> [Nat] -> Exp [Nat] type instance Eval (TakeDim d t s) = Eval ( - Flip (SetIndex d) s =<< - Min t =<< - UnsafeGetIndex d s + ModifyDim d (Min t) s ) -someTakeDim :: forall d t s. (HasShape s, KnownNat d, KnownNat t) => SNats (Eval (TakeDim d t s)) -someTakeDim = withSomeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) unsafeCoerce +-- | Drop along a dimension. +-- +-- >>> dropDim 2 1 [2,3,4] +-- [2,3,3] +dropDim :: Int -> Int -> [Int] -> [Int] +dropDim d t s = modifyDim d (max 0 . (\x -> x - t)) s -someTakeDim' :: forall d t s s'. (HasShape s, KnownNat d, KnownNat t) => SNats s' -someTakeDim' = withSomeSNats (fromIntegral <$> takeDim (int (SNat :: SNat d)) (int (SNat :: SNat t)) (ints (SNats :: SNats s))) unsafeCoerce +-- | Drop along a dimension. +-- +-- >>> :k! Eval (DropDim 2 1 [2,3,4]) +-- ... +-- = [2, 3, 3] +data DropDim :: Nat -> Nat -> [Nat] -> Exp [Nat] --- | Provide axiomatic proof, make sure you wrote it on paper! -unsafeAxiom :: Dict c -unsafeAxiom = unsafeCoerce (Dict :: Dict ()) +type instance Eval (DropDim d t s) = + Eval ( + ModifyDim d + (Max 0 <=< (Flip (Fcf.-) t)) + s) + +-- | delete the i'th dimension. No effect on a scalar. +-- +-- >>> deleteDim 1 [2, 3, 4] +-- [2,4] +-- >>> deleteDim 2 [] +-- [] +deleteDim :: Int -> [Int] -> [Int] +deleteDim i s = take i s ++ drop (i + 1) s -axiomTakeDim :: forall d t s. (KnownNat d, KnownNat t, KnownNats s) :- (KnownNats (Eval (TakeDim d t s))) -axiomTakeDim = Sub unsafeAxiom +-- | delete the i'th dimension +-- +-- >>> :k! Eval (DeleteDim 1 [2, 3, 4]) +-- ... +-- = [2, 4] +-- >>> :k! Eval (DeleteDim 1 '[]) +-- ... +-- = '[] +data DeleteDim :: Nat -> [Nat] -> Exp [Nat] + +type instance Eval (DeleteDim i ds) = + Eval (LiftM2 (Fcf.++) (Take i ds) (Drop (i + 1) ds)) + +-- | Insert a new dimension at a position (or at the end if > rank). +-- +-- >>> insertDim 1 3 [2,4] +-- [2,3,4] +-- >>> insertDim 0 4 [] +-- [4] +insertDim :: Int -> Int -> [Int] -> [Int] +insertDim d i s = take d s ++ (i : drop d s) + +-- | Insert a new dimension at a position (or at the end if > rank). +-- +-- >>> :k! Eval (InsertDim 1 3 [2,4]) +-- ... +-- = [2, 3, 4] +-- >>> :k! Eval (InsertDim 0 4 '[]) +-- ... +-- = '[4] +data InsertDim :: Nat -> Nat -> [Nat] -> Exp [Nat] + +type instance Eval (InsertDim d i ds) = + Eval (LiftM2 (Fcf.++) (Take d ds) ((Cons i) =<< (Drop d ds))) + +data InsertDimUncurried :: (Nat,Nat) -> [Nat] -> Exp [Nat] + +type instance Eval (InsertDimUncurried xs ds) = + Eval (InsertDim (Eval (Fst xs)) (Eval (Snd xs)) ds)