From 4fc40b4dffed9c135a54660a8f9867381988c502 Mon Sep 17 00:00:00 2001 From: Tony Day Date: Tue, 20 Aug 2024 08:09:27 +1000 Subject: [PATCH] upto insert --- src/NumHask/Array.hs | 2 +- src/NumHask/Array/Dynamic.hs | 6 +- src/NumHask/Array/Fixed.hs | 166 ++++++++++++++++++---- src/NumHask/Array/Shape.hs | 259 ++++++++++++++++++++--------------- 4 files changed, 295 insertions(+), 138 deletions(-) diff --git a/src/NumHask/Array.hs b/src/NumHask/Array.hs index 2f8d0c1..18a87a0 100644 --- a/src/NumHask/Array.hs +++ b/src/NumHask/Array.hs @@ -11,7 +11,7 @@ module NumHask.Array where import NumHask.Array.Fixed -import NumHask.Array.Shape hiding (asSingleton, asScalar, concatenate, rank, reorder, size, squeeze) +import NumHask.Array.Shape hiding (asSingleton, asScalar, concatenate, rank, reorder, size, squeeze, rotate) -- $imports -- diff --git a/src/NumHask/Array/Dynamic.hs b/src/NumHask/Array/Dynamic.hs index ce213db..01b5736 100644 --- a/src/NumHask/Array/Dynamic.hs +++ b/src/NumHask/Array/Dynamic.hs @@ -810,7 +810,7 @@ append :: Array a -> Array a -> Array a -append d a b = insert d (S.indexOf d (shape a)) a b +append d a b = insert d (S.unsafeGetIndex d (shape a)) a b -- | Insert along a dimension at the beginning. -- @@ -957,8 +957,6 @@ slices :: Array a slices ps a = dimsWise slice ps a --- * application - -- | Reduce along specified dimensions, using the supplied fold. -- -- >>> pretty $ reduces [0] sum a @@ -1610,7 +1608,7 @@ concats :: concats ds n a = backpermute concatDims unconcatDims a where concatDims s = S.insertDim n (S.size $ S.takeDims ds s) (S.deleteDims ds s) - unconcatDims s = S.insertDims ds (S.shapen (S.takeDims ds (shape a)) (S.indexOf n s)) (S.deleteDim n s) + unconcatDims s = S.insertDims ds (S.shapen (S.takeDims ds (shape a)) (S.unsafeGetIndex n s)) (S.deleteDim n s) -- | Rotate an array along a dimension. -- diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index fd37730..1f43b3b 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -75,6 +75,7 @@ module NumHask.Array.Fixed -- ** Single-dimension operators take, slice, + insert, -- * Operators takes, @@ -106,7 +107,11 @@ module NumHask.Array.Fixed -- * Shape manipulations reshape, + reorder, squeeze, + reverses, + rotate, + rotates, {- @@ -165,7 +170,7 @@ import Data.Vector qualified as V import Fcf hiding (type (&&), type (+), type (-), type (++)) import GHC.TypeNats import NumHask.Array.Dynamic qualified as D -import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze) +import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze, rotate, reorder) import NumHask.Array.Shape qualified as S import NumHask.Prelude as P hiding (Min, take, diff, zipWith, empty, sequence, toList, length) import Prettyprinter hiding (dot) @@ -629,13 +634,14 @@ diag a = tabulate (go . fromFins) -- [0,1,0], -- [0,0,2]] undiag :: - forall a s. + forall s' a s. ( HasShape s, - Additive a, - HasShape ((++) s s) + HasShape s', + s' ~ Eval ((++) s s), + Additive a ) => Array s a -> - Array ((++) s s) a + Array s' a undiag a = tabulate (go . fromFins) where go [] = index a (UnsafeFins []) @@ -746,9 +752,43 @@ slice :: Proxy l -> Array s a -> Array s' a - slice _ o _ a = unsafeBackpermute (S.modifyDim (valueOf @d) (+ o)) a +-- | Insert along a dimension at a position. +-- +-- >>> pretty $ insert (Proxy :: Proxy 2) 0 a (konst @[2,3] 0) +-- [[[0,0,1,2,3], +-- [0,4,5,6,7], +-- [0,8,9,10,11]], +-- [[0,12,13,14,15], +-- [0,16,17,18,19], +-- [0,20,21,22,23]]] +-- >>> toDynamic $ insert (Proxy :: Proxy 0) 0 (toScalar 1) (toScalar 2) +-- UnsafeArray [2] [2,1] +insert :: + forall s' s si d a. + (KnownNat d, + HasShape s, + HasShape si, + HasShape s', + HasShape (Eval (AsSingleton s)), + HasShape (Eval (AsSingleton si)), + s' ~ Eval (IncAt d (Eval (AsSingleton s))) + ) => + Proxy d -> + Int -> + Array s a -> + Array si a -> + Array s' a +insert _ i a b = tabulate go + where + go xs + | xs' !! d == i = index (asSingleton b) (UnsafeFins (S.deleteDim d xs')) + | xs' !! d < i = index (asSingleton a) (UnsafeFins xs') + | otherwise = index (asSingleton a) (UnsafeFins (S.decAt d xs')) + where xs' = fromFins xs + d = valueOf @d + -- | Takes the top-most elements according to the new dimensions. -- -- >>> pretty (takes @[1,2,2] a) @@ -1114,15 +1154,16 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b)) -- [[([1,1],[0,0]),([1,1],[0,1])], -- [([1,1],[1,0]),([1,1],[1,1])]]]] expand :: - forall s s' a b c. - ( HasShape s, - HasShape s', - HasShape ((++) s s') + forall sc sa sb a b c. + ( HasShape sa, + HasShape sb, + HasShape sc, + sc ~ Eval ((++) sa sb) ) => (a -> b -> c) -> - Array s a -> - Array s' b -> - Array ((++) s s') c + Array sa a -> + Array sb b -> + Array sc c expand f a b = tabulate (\i -> f (index a (UnsafeFins $ List.take r (fromFins i))) (index b (UnsafeFins $ drop r (fromFins i)))) where r = rank a @@ -1139,15 +1180,16 @@ expand f a b = tabulate (\i -> f (index a (UnsafeFins $ List.take r (fromFins i) -- [(0,4),(1,4),(2,4)], -- [(0,5),(1,5),(2,5)]] expandr :: - forall s s' a b c. - ( HasShape s, - HasShape s', - HasShape ((++) s s') + forall sc sa sb a b c. + ( HasShape sa, + HasShape sb, + HasShape sc, + sc ~ Eval ((++) sa sb) ) => (a -> b -> c) -> - Array s a -> - Array s' b -> - Array ((++) s s') c + Array sa a -> + Array sb b -> + Array sc c expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ drop r (fromFins i))) (index b (UnsafeFins $ List.take r (fromFins i)))) where r = rank a @@ -1204,15 +1246,15 @@ dot :: forall a b c d sa sb s' ss se. ( HasShape sa, HasShape sb, - HasShape (sa ++ sb), - se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb), + HasShape (Eval ((++) sa sb)), + se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)), HasShape se, KnownNat (Eval (Minimum se)), KnownNat (Eval (Rank sa) - 1), KnownNat (Eval (Rank sa)), ss ~ '[Eval (Minimum se)], HasShape ss, - s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb), + s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)), HasShape s' ) => (Array ss c -> d) -> @@ -1250,15 +1292,15 @@ mult :: Multiplicative a, HasShape sa, HasShape sb, - HasShape (sa ++ sb), - se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb), + HasShape (Eval ((++) sa sb)), + se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)), HasShape se, KnownNat (Eval (Minimum se)), KnownNat (Eval (Rank sa) - 1), KnownNat (Eval (Rank sa)), ss ~ '[Eval (Minimum se)], HasShape ss, - s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb), + s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)), HasShape s' ) => Array sa a -> @@ -1294,6 +1336,29 @@ reshape = unsafeBackpermute (shapen s . flatten s') s = shapeOf @s s' = shapeOf @s' +-- | Change the order of dimensions. +-- +-- >>> pretty $ reorder (Proxy :: Proxy [2,0,1]) a +-- [[[0,4,8], +-- [12,16,20]], +-- [[1,5,9], +-- [13,17,21]], +-- [[2,6,10], +-- [14,18,22]], +-- [[3,7,11], +-- [15,19,23]]] +reorder :: + forall dims s s' a. + (HasShape s, + HasShape s', + HasShape dims, + s' ~ Eval (Reorder s dims) + ) => + Proxy dims -> + Array s a -> + Array s' a +reorder _ a = unsafeBackpermute (\s -> S.insertDims (shapeOf @dims) s []) a + -- | Remove single dimensions. -- -- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int @@ -1341,6 +1406,55 @@ squeeze :: Array t a squeeze = unsafeModifyShape +-- | Reverses element order along specified dimensions. +-- +-- >>> pretty $ reverses [0,1] a +-- [[[20,21,22,23], +-- [16,17,18,19], +-- [12,13,14,15]], +-- [[8,9,10,11], +-- [4,5,6,7], +-- [0,1,2,3]]] +reverses :: + (HasShape s) => + [Int] -> + Array s a -> + Array s a +reverses ds a = unsafeBackpermute (S.reverseIndex ds (shape a)) a + +-- | Rotate an array along a dimension. +-- +-- >>> pretty $ rotate 1 2 a +-- [[[8,9,10,11], +-- [0,1,2,3], +-- [4,5,6,7]], +-- [[20,21,22,23], +-- [12,13,14,15], +-- [16,17,18,19]]] +rotate :: + (HasShape s) => + Int -> + Int -> + Array s a -> + Array s a +rotate d r a = unsafeBackpermute (S.modifyDim d (\i -> (r + i) `mod` (shape a !! d))) a + +-- | Rotate an array by/along offset,dimension tuples. +-- +-- >>> pretty $ rotates [(1, 2)] a +-- [[[8,9,10,11], +-- [0,1,2,3], +-- [4,5,6,7]], +-- [[20,21,22,23], +-- [12,13,14,15], +-- [16,17,18,19]]] +rotates :: + forall a s. + (HasShape s) => + [(Int, Int)] -> + Array s a -> + Array s a +rotates rs a = unsafeBackpermute (rotateIndex rs (shapeOf @s)) a {- -- | Reshape an array (with the same number of elements). diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index efbc363..c3b60a7 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -52,6 +52,8 @@ module NumHask.Array.Shape asScalar, AsScalar, GetIndex, + unsafeGetIndex, + UnsafeGetIndex, rotate, type (++), type (!!), @@ -64,23 +66,21 @@ module NumHask.Array.Shape rerank, size, Size, - indexOf, - IndexOf, Min, minimum, Minimum, - checkIndex, - CheckIndex, - checkIndexes, - CheckIndexes, - reverseIndex, modifyDim, - rotateIndex, + ModifyDim, + replaceDim, + ReplaceDim, + incAt, + IncAt, + decAt, + DecAt, insertDim, InsertDim, deleteDim, DeleteDim, - replaceDim, preDeletePositions, preInsertPositions, PosRelative, @@ -106,14 +106,20 @@ module NumHask.Array.Shape CheckInsert, reorder, Reorder, - CheckReorder, + ReorderOk, squeeze, Squeeze, - incAt, - decAt, Zip, Windows, Fcf.Eval, + + -- * Assertions + checkIndex, + CheckIndex, + + -- * index-only operations + reverseIndex, + rotateIndex, ) where @@ -136,7 +142,7 @@ import Unsafe.Coerce import Fcf hiding (type (&&), type (+), type (-), type (++)) import Fcf qualified import Fcf.Class.Foldable -import Fcf.Data.List (Take, Drop) +import Fcf.Data.List import Control.Monad -- $setup @@ -152,6 +158,7 @@ import Control.Monad -- | Get the value of a type level Nat. -- Use with explicit type application -- + -- >>> valueOf @42 -- 42 valueOf :: forall n. (KnownNat n) => Int @@ -438,22 +445,28 @@ type family GetIndexImpl (n :: Nat) (xs :: [k]) where GetIndexImpl 0 (x ': _) = 'Just x GetIndexImpl n (_ ': xs) = GetIndexImpl (n - 1) xs --- | indexOf i xs is the i'th element of xs (or zero if out-of-bounds) +-- | UnsafeGetIndex i xs is the i'th element of xs (or error if out-of-bounds) -- --- >>> indexOf 1 [2,3,4] --- 3 -indexOf :: Int -> [Int] -> Int -indexOf 0 (s : _) = s -indexOf n (_ : s) = indexOf (n - 1) s -indexOf _ _ = error "indexOf outside bounds" - -type family IndexOf (i :: Nat) (xs :: [Nat]) :: Nat where - IndexOf 0 (xs : _) = xs - IndexOf n (_ : xs) = IndexOf (n - 1) xs - IndexOf _ _ = L.TypeError ('Text "indexOf outside bounds") +-- >>> :k! Eval (UnsafeGetIndex 1 [2,3,4]) +-- ... +-- = 3 +-- >>> :k! Eval (UnsafeGetIndex 3 [2,3,4]) +-- ... +-- = (TypeError ...) +data UnsafeGetIndex :: Nat -> [a] -> Exp a +type instance Eval (UnsafeGetIndex n xs) = Eval (FromMaybe (L.TypeError (L.Text "UnsafeGetIndex out of bounds")) (Eval (GetIndex n xs))) --- type family Min (x :: Nat) (y :: Nat) :: Nat where --- Min x y = If (x <=? y) x y +-- | unsafeGetIndex i xs is the i'th element of xs (or error if out-of-bounds) +-- +-- >>> unsafeGetIndex 1 [2,3,4] +-- 3 +-- >>> unsafeGetIndex 3 [2,3,4] +-- *** Exception: unsafeGetIndex outside bounds +-- ... +unsafeGetIndex :: Int -> [Int] -> Int +unsafeGetIndex 0 (s : _) = s +unsafeGetIndex n (_ : s) = unsafeGetIndex (n - 1) s +unsafeGetIndex _ _ = error "unsafeGetIndex outside bounds" -- | minimum dimension -- @@ -485,20 +498,6 @@ data Min :: a -> a -> Exp a type instance Eval (Min a b) = If (Eval (a Fcf.< b)) a b -type family Init (a :: [k]) :: [k] where - Init '[] = L.TypeError ('Text "No init") - Init '[_] = '[] - Init (x : xs) = x : Init xs - -type family Last (a :: [k]) :: k where - Last '[] = L.TypeError ('Text "No last") - Last '[x] = x - Last (_ : xs) = Last xs - -type family (a :: [k]) ++ (b :: [k]) :: [k] where - '[] ++ b = b - (a : as) ++ b = a : (as ++ b) - -- | delete the i'th dimension -- -- >>> deleteDim 1 [2, 3, 4] @@ -550,6 +549,16 @@ type instance Eval (InsertDim i d ds) = modifyDim :: Int -> (Int -> Int) -> [Int] -> [Int] modifyDim d f xs = take d xs <> (pure . f) (xs List.!! d) <> drop (d + 1) xs +-- | modify an index at a specific dimension. Unmodified if out of bounds. +-- +-- >>> :k! Eval (ModifyDim 0 ((Fcf.+) 1) [0,1,2]) +-- ... +-- = [1, 1, 2] +data ModifyDim :: Nat -> (Nat -> Exp Nat) -> [Nat] -> Exp [Nat] + +type instance Eval (ModifyDim d f ds) = + Eval (FromMaybe ds =<< (Map (Flip (SetIndex d) ds) =<< (Map f =<< (GetIndex d ds)))) + -- | replace an index at a specific dimension. -- -- >>> replaceDim 0 1 [2,3,4] @@ -557,25 +566,50 @@ modifyDim d f xs = take d xs <> (pure . f) (xs List.!! d) <> drop (d + 1) xs replaceDim :: Int -> Int -> [Int] -> [Int] replaceDim d x xs = modifyDim d (const x) xs --- | reverse an index along specific dimensions. +-- | replace an index at a specific dimension. -- --- >>> reverseIndex [0] [2,3,4] [0,1,2] --- [1,1,2] -reverseIndex :: [Int] -> [Int] -> [Int] -> [Int] -reverseIndex ds ns xs = fmap (\(i, x, n) -> bool x (n - 1 - x) (i `elem` ds)) (zip3 [0 ..] xs ns) +-- >>> :k! Eval (ReplaceDim 0 1 [2,3,4]) +-- ... +-- = [1, 3, 4] +data ReplaceDim :: Nat -> Nat -> [Nat] -> Exp [Nat] -type Reverse (a :: [k]) = ReverseGo a '[] +type instance Eval (ReplaceDim d x ds) = + Eval (SetIndex d x ds) -type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where - ReverseGo '[] b = b - ReverseGo (a : as) b = ReverseGo as (a : b) +-- | Increment the index at a dimension of a shape by one. +-- +-- >>> incAt 1 [2,3,4] +-- [2,4,4] +incAt :: Int -> [Int] -> [Int] +incAt d ds = modifyDim d (+1) ds --- | rotate an index along specific dimensions. +-- | Increment the index at a dimension of a shape by one. -- --- >>> rotateIndex [(0,1)] [2,3,4] [0,1,2] --- [1,1,2] -rotateIndex :: [(Int, Int)] -> [Int] -> [Int] -> [Int] -rotateIndex rs s xs = foldr (\(d, r) acc -> modifyDim d (\x -> ((x + r) `mod`) (s List.!! d)) acc) xs rs +-- >>> :k! Eval (IncAt 1 [2,3,4]) +-- ... +-- = [2, 4, 4] +data IncAt :: Nat -> t Nat -> Exp (t Nat) + +type instance Eval (IncAt d ds) = + Eval (ModifyDim d ((Fcf.+) 1) ds) + +-- | Decrement the index at a dimension os a shape by one. +-- +-- >>> decAt 1 [2,3,4] +-- [2,2,4] +decAt :: Int -> [Int] -> [Int] +decAt d ds = modifyDim d (\x -> x - 1) ds + +-- | Decrement the index at a dimension of a shape by one. +-- +-- >>> :k! Eval (DecAt 1 [2,3,4]) +-- ... +-- = [2, 2, 4] +data DecAt :: Nat -> t Nat -> Exp (t Nat) + +type instance Eval (DecAt d ds) = + Eval (ModifyDim d (Flip (Fcf.-) 1) ds) + -- | Convert a list of position that reference deletions according to a final shape to one that references deletions relative to an initial shape. -- @@ -612,7 +646,7 @@ type family PosRelative (s :: [Nat]) where PosRelative s = PosRelativeGo s '[] type family PosRelativeGo (r :: [Nat]) (s :: [Nat]) where - PosRelativeGo '[] r = Reverse r + PosRelativeGo '[] r = Eval (Reverse r) PosRelativeGo (x : xs) r = PosRelativeGo (DecMap x xs) (x : r) type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where @@ -647,7 +681,7 @@ insertDims xs ys as = insertDimsGo (preInsertPositions xs) ys as insertDimsGo _ _ _ = throw (NumHaskException "mismatched ranks") type family InsertDims (xs :: [Nat]) (ys :: [Nat]) (as :: [Nat]) where - InsertDims xs ys as = InsertDimsGo (Reverse (PosRelative (Reverse xs))) ys as + InsertDims xs ys as = InsertDimsGo (Eval (Reverse (PosRelative (Eval (Reverse xs))))) ys as type family InsertDimsGo (xs :: [Nat]) (ys :: [Nat]) (as :: [Nat]) where InsertDimsGo '[] _ as' = as' @@ -717,27 +751,6 @@ exclude r xs = deleteDims xs [0 .. (r - 1)] type family Exclude (r :: Nat) (i :: [Nat]) where Exclude r i = DeleteDims (EnumerateGo r) i --- | /checkIndex i n/ checks if /i/ is a valid index of a list of length /n/ --- --- >>> checkIndex 0 0 --- True --- >>> checkIndex 3 2 --- False -checkIndex :: Int -> Int -> Bool -checkIndex i n = (zero <= i && i + one <= n) || (i == zero && n == zero) - -type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where - CheckIndex i n = - If ((0 <=? i) && (i + 1 <=? n)) 'True (L.TypeError ('Text "index outside range")) - --- | /checkIndexes is n/ check if /is/ are valid indexes of a list of length /n/ -checkIndexes :: [Int] -> Int -> Bool -checkIndexes is n = all (`checkIndex` n) is - -type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where - CheckIndexes '[] _ = 'True - CheckIndexes (i : is) n = CheckIndex i n && CheckIndexes is n - -- | concatenate two arrays at dimension i -- -- Bespoke logic for scalars. @@ -754,29 +767,21 @@ concatenate :: Int -> [Int] -> [Int] -> [Int] concatenate _ [] [] = [2] concatenate _ [] [x] = [x + 1] concatenate _ [x] [] = [x + 1] -concatenate i s0 s1 = take i s0 ++ (indexOf i s0 + indexOf i s1 : drop (i + 1) s0) +concatenate i s0 s1 = take i s0 ++ (unsafeGetIndex i s0 + unsafeGetIndex i s1 : drop (i + 1) s0) -type Concatenate i s0 s1 = Eval (Take i s0) ++ (IndexOf i s0 + IndexOf i s1 : Eval (Drop (i + 1) s0)) +type Concatenate i s0 s1 = Eval (Take i s0) ++ (Eval (UnsafeGetIndex i s0) + Eval (UnsafeGetIndex i s1) : Eval (Drop (i + 1) s0)) type CheckConcatenate i s0 s1 s = - ( CheckIndex i (Eval (Rank s0)) + ( Eval (CheckIndex i (Eval (Rank s0))) && DeleteDim i s0 == DeleteDim i s1 && Rank s0 == Rank s1 ) ~ 'True type CheckInsert d i s = - (CheckIndex d (Eval (Rank s)) && CheckIndex i (IndexOf d s)) ~ 'True - -type Insert d s = Eval (Take d s) ++ (IndexOf d s + 1 : Eval (Drop (d + 1) s)) + (Eval (CheckIndex d (Eval (Rank s))) && Eval (CheckIndex i (Eval (UnsafeGetIndex d s)))) ~ 'True --- | /incAt d s/ increments the index at /d/ of shape /s/ by one. -incAt :: Int -> [Int] -> [Int] -incAt d s = take d s ++ (indexOf d s + 1 : drop (d + 1) s) - --- | /decAt d s/ decrements the index at /d/ of shape /s/ by one. -decAt :: Int -> [Int] -> [Int] -decAt d s = take d s ++ (indexOf d s - 1 : drop (d + 1) s) +type Insert d s = Eval (Take d s) ++ (Eval (UnsafeGetIndex d s) + 1 : Eval (Drop (d + 1) s)) -- | /reorder s i/ reorders the dimensions of shape /s/ according to a list of positions /i/ -- @@ -785,22 +790,20 @@ decAt d s = take d s ++ (indexOf d s - 1 : drop (d + 1) s) reorder :: [Int] -> [Int] -> [Int] reorder [] _ = [] reorder _ [] = [] -reorder s (d : ds) = indexOf d s : reorder s ds - -type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where - Reorder '[] _ = '[] - Reorder _ '[] = '[] - Reorder s (d : ds) = IndexOf d s : Reorder s ds - -type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where - CheckReorder ds s = - If - ( Rank ds == Rank s - && CheckIndexes ds (Eval (Rank s)) - ) - 'True - (L.TypeError ('Text "bad dimensions")) - ~ 'True +reorder s (d : ds) = unsafeGetIndex d s : reorder s ds + +data Reorder :: t Nat -> t Nat -> Exp (t Nat) + +type instance Eval (Reorder ds xs) = + If ( Eval (ReorderOk ds xs)) + (Eval (Map (Flip UnsafeGetIndex ds) xs)) + (L.TypeError ('Text "Reorder dimension indices out of bounds")) + +data ReorderOk :: t Nat -> t Nat -> Exp Bool + +type instance Eval (ReorderOk ds xs) = + Eval (TyEq (Eval (Rank ds)) (Eval (Rank xs))) && + Eval (And =<< Map (Flip CheckIndex (Eval (Rank ds))) xs) -- | remove 1's from a list -- @@ -844,3 +847,45 @@ type family Windows (ws :: [Nat]) (xs :: [Nat]) where -- c = List.length xs -- df s = List.zipWith (\s' x' -> s' - x' + 1) s xs <> xs <> List.drop c s + +-- | Check if i is a valid index of a dimension of length l +-- +-- >>> checkIndex 0 2 +-- True +-- >>> checkIndex 2 2 +-- False +checkIndex :: Int -> Int -> Bool +checkIndex i n = (zero <= i && i + one <= n) + +-- | Check if i is a valid index of a dimension of length l +-- FIXME: rename to In +-- +-- >>> :k! Eval (CheckIndex 0 2) +-- ... +-- = True +-- >>> :k! Eval (CheckIndex 2 2) +-- ... +-- = False +data CheckIndex :: Nat -> Nat -> Exp Bool + +type instance Eval (CheckIndex x d) = + Eval ((Fcf.<) x d) + +-- | reverse an index along specific dimensions. +-- +-- >>> reverseIndex [0] [2,3,4] [0,1,2] +-- [1,1,2] +reverseIndex :: [Int] -> [Int] -> [Int] -> [Int] +reverseIndex ds ns xs = fmap (\(i, x, n) -> bool x (n - 1 - x) (i `elem` ds)) (zip3 [0 ..] xs ns) + +-- | rotate an index along specific dimensions. +-- +-- >>> rotateIndex [(0,1)] [2,3,4] [0,1,2] +-- [1,1,2] +rotateIndex :: [(Int, Int)] -> [Int] -> [Int] -> [Int] +rotateIndex rs s xs = foldr (\(d, r) acc -> modifyDim d (\x -> ((x + r) `mod`) (s List.!! d)) acc) xs rs + +data EnumFromTo :: Nat -> Nat -> Exp (t Nat) + +type instance Eval (EnumFromTo a b) = + If (Eval (a Fcf.> b)) '[] (a : Eval (EnumFromTo (a+1) b))