Skip to content

Commit

Permalink
upto insert
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyday567 committed Aug 19, 2024
1 parent 5313661 commit 4fc40b4
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 138 deletions.
2 changes: 1 addition & 1 deletion src/NumHask/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module NumHask.Array
where

import NumHask.Array.Fixed
import NumHask.Array.Shape hiding (asSingleton, asScalar, concatenate, rank, reorder, size, squeeze)
import NumHask.Array.Shape hiding (asSingleton, asScalar, concatenate, rank, reorder, size, squeeze, rotate)

-- $imports
--
Expand Down
6 changes: 2 additions & 4 deletions src/NumHask/Array/Dynamic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ append ::
Array a ->
Array a ->
Array a
append d a b = insert d (S.indexOf d (shape a)) a b
append d a b = insert d (S.unsafeGetIndex d (shape a)) a b

-- | Insert along a dimension at the beginning.
--
Expand Down Expand Up @@ -957,8 +957,6 @@ slices ::
Array a
slices ps a = dimsWise slice ps a

-- * application

-- | Reduce along specified dimensions, using the supplied fold.
--
-- >>> pretty $ reduces [0] sum a
Expand Down Expand Up @@ -1610,7 +1608,7 @@ concats ::
concats ds n a = backpermute concatDims unconcatDims a
where
concatDims s = S.insertDim n (S.size $ S.takeDims ds s) (S.deleteDims ds s)
unconcatDims s = S.insertDims ds (S.shapen (S.takeDims ds (shape a)) (S.indexOf n s)) (S.deleteDim n s)
unconcatDims s = S.insertDims ds (S.shapen (S.takeDims ds (shape a)) (S.unsafeGetIndex n s)) (S.deleteDim n s)

-- | Rotate an array along a dimension.
--
Expand Down
166 changes: 140 additions & 26 deletions src/NumHask/Array/Fixed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ module NumHask.Array.Fixed
-- ** Single-dimension operators
take,
slice,
insert,

-- * Operators
takes,
Expand Down Expand Up @@ -106,7 +107,11 @@ module NumHask.Array.Fixed

-- * Shape manipulations
reshape,
reorder,
squeeze,
reverses,
rotate,
rotates,

{-
Expand Down Expand Up @@ -165,7 +170,7 @@ import Data.Vector qualified as V
import Fcf hiding (type (&&), type (+), type (-), type (++))
import GHC.TypeNats
import NumHask.Array.Dynamic qualified as D
import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze)
import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze, rotate, reorder)
import NumHask.Array.Shape qualified as S
import NumHask.Prelude as P hiding (Min, take, diff, zipWith, empty, sequence, toList, length)
import Prettyprinter hiding (dot)
Expand Down Expand Up @@ -629,13 +634,14 @@ diag a = tabulate (go . fromFins)
-- [0,1,0],
-- [0,0,2]]
undiag ::
forall a s.
forall s' a s.
( HasShape s,
Additive a,
HasShape ((++) s s)
HasShape s',
s' ~ Eval ((++) s s),
Additive a
) =>
Array s a ->
Array ((++) s s) a
Array s' a
undiag a = tabulate (go . fromFins)
where
go [] = index a (UnsafeFins [])
Expand Down Expand Up @@ -746,9 +752,43 @@ slice ::
Proxy l ->
Array s a ->
Array s' a

slice _ o _ a = unsafeBackpermute (S.modifyDim (valueOf @d) (+ o)) a

-- | Insert along a dimension at a position.
--
-- >>> pretty $ insert (Proxy :: Proxy 2) 0 a (konst @[2,3] 0)
-- [[[0,0,1,2,3],
-- [0,4,5,6,7],
-- [0,8,9,10,11]],
-- [[0,12,13,14,15],
-- [0,16,17,18,19],
-- [0,20,21,22,23]]]
-- >>> toDynamic $ insert (Proxy :: Proxy 0) 0 (toScalar 1) (toScalar 2)
-- UnsafeArray [2] [2,1]
insert ::
forall s' s si d a.
(KnownNat d,
HasShape s,
HasShape si,
HasShape s',
HasShape (Eval (AsSingleton s)),
HasShape (Eval (AsSingleton si)),
s' ~ Eval (IncAt d (Eval (AsSingleton s)))
) =>
Proxy d ->
Int ->
Array s a ->
Array si a ->
Array s' a
insert _ i a b = tabulate go
where
go xs
| xs' !! d == i = index (asSingleton b) (UnsafeFins (S.deleteDim d xs'))
| xs' !! d < i = index (asSingleton a) (UnsafeFins xs')
| otherwise = index (asSingleton a) (UnsafeFins (S.decAt d xs'))
where xs' = fromFins xs
d = valueOf @d

-- | Takes the top-most elements according to the new dimensions.
--
-- >>> pretty (takes @[1,2,2] a)
Expand Down Expand Up @@ -1114,15 +1154,16 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b))
-- [[([1,1],[0,0]),([1,1],[0,1])],
-- [([1,1],[1,0]),([1,1],[1,1])]]]]
expand ::
forall s s' a b c.
( HasShape s,
HasShape s',
HasShape ((++) s s')
forall sc sa sb a b c.
( HasShape sa,
HasShape sb,
HasShape sc,
sc ~ Eval ((++) sa sb)
) =>
(a -> b -> c) ->
Array s a ->
Array s' b ->
Array ((++) s s') c
Array sa a ->
Array sb b ->
Array sc c
expand f a b = tabulate (\i -> f (index a (UnsafeFins $ List.take r (fromFins i))) (index b (UnsafeFins $ drop r (fromFins i))))
where
r = rank a
Expand All @@ -1139,15 +1180,16 @@ expand f a b = tabulate (\i -> f (index a (UnsafeFins $ List.take r (fromFins i)
-- [(0,4),(1,4),(2,4)],
-- [(0,5),(1,5),(2,5)]]
expandr ::
forall s s' a b c.
( HasShape s,
HasShape s',
HasShape ((++) s s')
forall sc sa sb a b c.
( HasShape sa,
HasShape sb,
HasShape sc,
sc ~ Eval ((++) sa sb)
) =>
(a -> b -> c) ->
Array s a ->
Array s' b ->
Array ((++) s s') c
Array sa a ->
Array sb b ->
Array sc c
expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ drop r (fromFins i))) (index b (UnsafeFins $ List.take r (fromFins i))))
where
r = rank a
Expand Down Expand Up @@ -1204,15 +1246,15 @@ dot ::
forall a b c d sa sb s' ss se.
( HasShape sa,
HasShape sb,
HasShape (sa ++ sb),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape (Eval ((++) sa sb)),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)),
HasShape se,
KnownNat (Eval (Minimum se)),
KnownNat (Eval (Rank sa) - 1),
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)),
HasShape s'
) =>
(Array ss c -> d) ->
Expand Down Expand Up @@ -1250,15 +1292,15 @@ mult ::
Multiplicative a,
HasShape sa,
HasShape sb,
HasShape (sa ++ sb),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape (Eval ((++) sa sb)),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)),
HasShape se,
KnownNat (Eval (Minimum se)),
KnownNat (Eval (Rank sa) - 1),
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb)),
HasShape s'
) =>
Array sa a ->
Expand Down Expand Up @@ -1294,6 +1336,29 @@ reshape = unsafeBackpermute (shapen s . flatten s')
s = shapeOf @s
s' = shapeOf @s'

-- | Change the order of dimensions.
--
-- >>> pretty $ reorder (Proxy :: Proxy [2,0,1]) a
-- [[[0,4,8],
-- [12,16,20]],
-- [[1,5,9],
-- [13,17,21]],
-- [[2,6,10],
-- [14,18,22]],
-- [[3,7,11],
-- [15,19,23]]]
reorder ::
forall dims s s' a.
(HasShape s,
HasShape s',
HasShape dims,
s' ~ Eval (Reorder s dims)
) =>
Proxy dims ->
Array s a ->
Array s' a
reorder _ a = unsafeBackpermute (\s -> S.insertDims (shapeOf @dims) s []) a

-- | Remove single dimensions.
--
-- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int
Expand Down Expand Up @@ -1341,6 +1406,55 @@ squeeze ::
Array t a
squeeze = unsafeModifyShape

-- | Reverses element order along specified dimensions.
--
-- >>> pretty $ reverses [0,1] a
-- [[[20,21,22,23],
-- [16,17,18,19],
-- [12,13,14,15]],
-- [[8,9,10,11],
-- [4,5,6,7],
-- [0,1,2,3]]]
reverses ::
(HasShape s) =>
[Int] ->
Array s a ->
Array s a
reverses ds a = unsafeBackpermute (S.reverseIndex ds (shape a)) a

-- | Rotate an array along a dimension.
--
-- >>> pretty $ rotate 1 2 a
-- [[[8,9,10,11],
-- [0,1,2,3],
-- [4,5,6,7]],
-- [[20,21,22,23],
-- [12,13,14,15],
-- [16,17,18,19]]]
rotate ::
(HasShape s) =>
Int ->
Int ->
Array s a ->
Array s a
rotate d r a = unsafeBackpermute (S.modifyDim d (\i -> (r + i) `mod` (shape a !! d))) a

-- | Rotate an array by/along offset,dimension tuples.
--
-- >>> pretty $ rotates [(1, 2)] a
-- [[[8,9,10,11],
-- [0,1,2,3],
-- [4,5,6,7]],
-- [[20,21,22,23],
-- [12,13,14,15],
-- [16,17,18,19]]]
rotates ::
forall a s.
(HasShape s) =>
[(Int, Int)] ->
Array s a ->
Array s a
rotates rs a = unsafeBackpermute (rotateIndex rs (shapeOf @s)) a

{-
-- | Reshape an array (with the same number of elements).
Expand Down
Loading

0 comments on commit 4fc40b4

Please sign in to comment.