Skip to content

Commit

Permalink
upto modifyDim on first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyday567 committed Aug 17, 2024
1 parent c7549e2 commit 5313661
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 111 deletions.
90 changes: 55 additions & 35 deletions src/NumHask/Array/Fixed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ module NumHask.Array.Fixed
-- ** Element-level operators
zipWith,
modify,
diff,
imap,

-- ** Operator generalisers
Expand Down Expand Up @@ -106,6 +105,7 @@ module NumHask.Array.Fixed
mult,

-- * Shape manipulations
reshape,
squeeze,

{-
Expand Down Expand Up @@ -162,7 +162,7 @@ import Data.Functor.Classes
import Data.Functor.Rep
import Data.Proxy
import Data.Vector qualified as V
-- import GHC.TypeLits
import Fcf hiding (type (&&), type (+), type (-), type (++))
import GHC.TypeNats
import NumHask.Array.Dynamic qualified as D
import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze)
Expand Down Expand Up @@ -494,14 +494,14 @@ isScalar a = rank a == zero
--
-- >>> asSingleton (toScalar 4)
-- [4]
asSingleton :: (HasShape s, HasShape s', s' ~ AsSingleton s) => Array s a -> Array s' a
asSingleton :: (HasShape s, HasShape s', s' ~ Eval (AsSingleton s)) => Array s a -> Array s' a
asSingleton = unsafeModifyShape

-- | Convert arrays with shape [1] to scalars.
--
-- >>> pretty (asScalar (singleton 3))
-- 3
asScalar :: (HasShape s, HasShape s', s' ~ AsScalar s) => Array s a -> Array s' a
asScalar :: (HasShape s, HasShape s', s' ~ Eval (AsScalar s)) => Array s a -> Array s' a
asScalar = unsafeModifyShape

-- * Creation
Expand Down Expand Up @@ -610,12 +610,13 @@ singleton a = unsafeArray (V.singleton a)
-- >>> pretty $ diag (ident @[3,3])
-- [1,1,1]
diag ::
forall a s.
forall s' a s.
( HasShape s,
HasShape '[Minimum s]
HasShape s',
s' ~ '[Eval (Minimum s)]
) =>
Array s a ->
Array '[Minimum s] a
Array s' a
diag a = tabulate (go . fromFins)
where
go [] = index a (UnsafeFins [])
Expand Down Expand Up @@ -658,15 +659,6 @@ zipWith f (asVector -> a) (asVector -> b) = unsafeArray (V.zipWith f a b)
modify :: (HasShape s) => Fins s -> (a -> a) -> Array s a -> Array s a
modify ds f a = tabulate (\s -> bool id f (s == ds) (index a s))

-- | Row-wise difference an array using the supplied function with a lag.
--
-- FIXME:
-- > pretty $ diff 1 (-) (range [3,2])
-- [[2,2],
-- [2,2]]
diff :: (HasShape s, s' ~ ReplaceDim 0 (s!!0-n) s) => Proxy Int -> (a -> a -> b) -> Array s a -> Array s' b
diff _ _ _ = undefined -- zipWith f (rowWise (dimsWise drop) [n] a) (rowWise (dimsWise drop) [-n] a)

-- | Maps an index function at element-level.
--
-- >>> pretty $ imap (\xs x -> x - sum xs) a
Expand Down Expand Up @@ -726,7 +718,7 @@ take ::
KnownNat d,
KnownNat t,
-- Fin (Rank s) ~ SNat d,
ReplaceDim d (S.Min t (s !! d)) s ~ s'
Eval (SetIndex d (Eval (Min t (s !! d))) s) ~ s'
) =>
SNat d ->
SNat t ->
Expand All @@ -748,7 +740,7 @@ slice ::
(HasShape s,
HasShape s',
KnownNat d,
ReplaceDim d l s ~ s') =>
Eval (SetIndex d l s) ~ s') =>
Proxy d ->
Int ->
Proxy l ->
Expand All @@ -771,7 +763,7 @@ takes ::
forall s' s a.
( HasShape s,
HasShape s',
ShapeLTE s' s ~ 'True
Eval (ShapeLTE s' s) ~ 'True
) =>
Array s a ->
Array s' a
Expand All @@ -786,7 +778,7 @@ drops ::
forall s' s a.
( HasShape s,
HasShape s',
ShapeLTE s' s ~ 'True
Eval (ShapeLTE s' s) ~ 'True
) =>
Array s a ->
Array s' a
Expand Down Expand Up @@ -1170,14 +1162,14 @@ expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ drop r (fromFins i))) (
-- [32,77]]
contract ::
forall a b s ss s' ds.
( KnownNat (Minimum (TakeDims ds s)),
( KnownNat (Eval (Minimum (TakeDims ds s))),
HasShape (TakeDims ds s),
HasShape s,
HasShape ds,
HasShape ss,
HasShape s',
s' ~ DeleteDims ds s,
ss ~ '[Minimum (TakeDims ds s)]
ss ~ '[Eval (Minimum (TakeDims ds s))]
) =>
(Array ss a -> b) ->
Proxy ds ->
Expand Down Expand Up @@ -1213,22 +1205,22 @@ dot ::
( HasShape sa,
HasShape sb,
HasShape (sa ++ sb),
se ~ TakeDims '[Rank sa - 1, Rank sa] (sa ++ sb),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape se,
KnownNat (Minimum se),
KnownNat (Rank sa - 1),
KnownNat (Rank sa),
ss ~ '[Minimum se],
KnownNat (Eval (Minimum se)),
KnownNat (Eval (Rank sa) - 1),
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ DeleteDims '[Rank sa - 1, Rank sa] (sa ++ sb),
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape s'
) =>
(Array ss c -> d) ->
(a -> b -> c) ->
Array sa a ->
Array sb b ->
Array s' d
dot f g a b = contract f (Proxy :: Proxy '[Rank sa - 1, Rank sa]) (expand g a b)
dot f g a b = contract f (Proxy :: Proxy '[Eval (Rank sa) - 1, Eval (Rank sa)]) (expand g a b)

-- | Array multiplication.
--
Expand Down Expand Up @@ -1259,21 +1251,49 @@ mult ::
HasShape sa,
HasShape sb,
HasShape (sa ++ sb),
se ~ TakeDims '[Rank sa - 1, Rank sa] (sa ++ sb),
se ~ TakeDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape se,
KnownNat (Minimum se),
KnownNat (Rank sa - 1),
KnownNat (Rank sa),
ss ~ '[Minimum se],
KnownNat (Eval (Minimum se)),
KnownNat (Eval (Rank sa) - 1),
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ DeleteDims '[Rank sa - 1, Rank sa] (sa ++ sb),
s' ~ DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (sa ++ sb),
HasShape s'
) =>
Array sa a ->
Array sb a ->
Array s' a
mult = dot sum (*)

-- | Reshape an array (with the same number of elements).
--
-- >>> pretty $ reshape @[4,3,2] a
-- [[[0,1],
-- [2,3],
-- [4,5]],
-- [[6,7],
-- [8,9],
-- [10,11]],
-- [[12,13],
-- [14,15],
-- [16,17]],
-- [[18,19],
-- [20,21],
-- [22,23]]]
reshape ::
forall s' s a.
( Eval (Size s) ~ Eval (Size s'),
HasShape s,
HasShape s'
) =>
Array s a ->
Array s' a
reshape = unsafeBackpermute (shapen s . flatten s')
where
s = shapeOf @s
s' = shapeOf @s'

-- | Remove single dimensions.
--
-- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int
Expand Down
Loading

0 comments on commit 5313661

Please sign in to comment.