diff --git a/numhask-array.cabal b/numhask-array.cabal index a84c113..672c39e 100644 --- a/numhask-array.cabal +++ b/numhask-array.cabal @@ -57,6 +57,7 @@ library , adjunctions >=4.0 && <5 , base >=4.14 && <5 , distributive >=0.4 && <0.7 + , first-class-families , numhask >=0.12 && <0.13 , prettyprinter >= 1.7 && <1.8 , random >=1.2 && <1.3 diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index e4f029b..3e59170 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -105,6 +105,9 @@ module NumHask.Array.Fixed dot, mult, + -- * Shape manipulations + squeeze, + {- reshape, @@ -162,7 +165,7 @@ import Data.Vector qualified as V -- import GHC.TypeLits import GHC.TypeNats import NumHask.Array.Dynamic qualified as D -import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton) +import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze) import NumHask.Array.Shape qualified as S import NumHask.Prelude as P hiding (Min, take, diff, zipWith, empty, sequence, toList, length) import Prettyprinter hiding (dot) @@ -225,6 +228,8 @@ import Data.List qualified as List -- | A multidimensional array with a type-level shape -- +-- >>> array @[2,3,4] [1..24::Int] +-- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24] -- >>> array [1..24] :: Array '[2,3,4] Int -- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24] -- >>> pretty (array [1..24] :: Array '[2,3,4] Int) @@ -240,7 +245,8 @@ import Data.List qualified as List -- -- In many spots, [TypeApplication](https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/type_applications.html) can be cleaner. -- --- >>> array @[2,3] @[Int] [1..6] +-- FIXME: +-- > array @[2,3] @[Int] [1..6] -- [1,2,3,4,5,6] -- -- >>> index a (S.UnsafeFins [1,2,3]) @@ -398,7 +404,7 @@ safeArray v = -- -- >>> array [0..22] :: Array [2,3,4] Int -- *** Exception: NumHaskException {errorMessage = "ShapeMismatch"} -array :: (HasShape s, FromVector t a) => t -> Array s a +array :: forall s a t. (HasShape s, FromVector t a) => t -> Array s a array v = fromMaybe (throw (NumHaskException "ShapeMismatch")) (safeArray v) @@ -812,7 +818,7 @@ indexes _ xs a = unsafeBackpermute (S.insertDims (shapeOf @ds) xs) a -- s :: Array '[4] Int -- -- >>> pretty $ s --- [17,18,19,20] +-- [16,17,18,19] indexesExcept :: forall ds s s' a. ( HasShape s, @@ -854,7 +860,7 @@ lasts ds a = indexes ds lastxs a -- | Select the tail elements along the supplied dimensions -- -- FIXME: Get new shape into the type level. --- >>> pretty $ tails (Proxy :: Proxy [0,2]) (Proxy :: Proxy [1,2,3]) a +-- > pretty $ tails (Proxy :: Proxy [0,2]) (Proxy :: Proxy [1,2,3]) a -- [[[13,14,15], -- [17,18,19], -- [21,22,23]]] @@ -903,7 +909,7 @@ slices _ o _ a = unsafeBackpermute (S.modifyDims (shapeOf @ds) (fmap (+) o)) a -- -- > a == (fromScalar <$> extracts [0..rank a] a) -- --- >>> pretty $ shape <$> extracts (Proxy :: Proxy [0]) a +-- >>> pretty $ shape <$> extracts (Proxy :: Proxy '[0]) a -- [[3,4],[3,4]] extracts :: forall ds st si so a. @@ -921,8 +927,8 @@ extracts d a = tabulate (\s -> indexes d (fromFins s) a) -- | Extracts /except/ dimensions to an outer layer. -- --- >>> let e = extractsExcept [1,2] a --- >>> pretty $ shape <$> extracts [0] a +-- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a +-- >>> pretty $ shape <$> e -- [[3,4],[3,4]] extractsExcept :: forall ds st si so a. @@ -942,7 +948,7 @@ extractsExcept d a = tabulate go -- | Reduce along specified dimensions, using the supplied fold. -- --- >>> pretty $ reduces (Proxy :: Proxy [0]) sum a +-- >>> pretty $ reduces (Proxy :: Proxy '[0]) sum a -- [66,210] -- >>> pretty $ reduces (Proxy :: Proxy [0,2]) sum a -- [[12,15,18,21], @@ -983,8 +989,8 @@ traverses ds f a = joins ds <$> traverse (traverse f) (extracts ds a) -- | Join inner and outer dimension layers by supplied dimensions. No checks on shape. -- --- >>> let e = extracts [1,0] a --- >>> let j = joins [1,0] e +-- >>> let e = extracts (Proxy :: Proxy [1,0]) a +-- >>> let j = joins (Proxy :: Proxy [1,0]) e -- >>> a == j -- True joins :: @@ -1003,14 +1009,15 @@ joins _ a = tabulate go -- | Join inner and outer dimension layers in outer dimension order. -- --- >>> a == join (extracts [0,1] a) +-- FIXME: +-- > a == join (extracts (Proxy :: Proxy [0,1]) a) -- True join :: forall a si so st. (HasShape st, HasShape si, HasShape so, - PrependDims so si ~ st) => + PrependDims si so ~ st) => Array so (Array si a) -> Array st a join a = tabulate go @@ -1020,7 +1027,7 @@ join a = tabulate go -- | Maps a function along specified dimensions. -- --- >>> :t maps (transpose) (Proxy :: Proxy '[1]) a +-- > :t maps (transpose) (Proxy :: Proxy '[1]) a -- maps (transpose) (Proxy :: Proxy '[1]) a :: Array [4, 3, 2] Int maps :: forall ds st st' si si' so a b. @@ -1041,13 +1048,10 @@ maps :: Array st' b maps f d a = joins d (fmapRep f (extracts d a)) --- | Filters along specified dimensions (which are flattened). +-- | Filters along specified dimensions (which are flattened as a dynamic array). -- -- >>> pretty $ filters (Proxy :: Proxy [0,1]) (any ((==0) . (`mod` 7))) a --- [[0,1,2,3], --- [4,5,6,7], --- [12,13,14,15], --- [20,21,22,23]] +-- [[0,1,2,3],[4,5,6,7],[12,13,14,15],[20,21,22,23]] filters :: forall ds si so a. ( HasShape ds, @@ -1064,7 +1068,7 @@ filters ds p a = D.asArray $ V.filter p $ asVector (extracts ds a) -- | Zips two arrays with a function along specified dimensions. -- --- >>> pretty $ zips (Proxy :: Proxy [0,1]) (zipWith (,)) a (reverses [0] a) +-- > pretty $ zips (Proxy :: Proxy [0,1]) (zipWith (,)) a (reverses [0] a) -- [[[(0,12),(1,13),(2,14),(3,15)], -- [(4,16),(5,17),(6,18),(7,19)], -- [(8,20),(9,21),(10,22),(11,23)]], @@ -1099,7 +1103,7 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b)) -- -- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category. -- --- >>> x = array [3] [1,2,3] +-- >>> x = array [1,2,3] :: Array '[3] Int -- >>> pretty $ expand (*) x x -- [[1,2,3], -- [2,4,6], @@ -1107,7 +1111,7 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b)) -- -- Alternatively, expand can be understood as representing the permutation of element pairs of two arrays, so like the Applicative List instance. -- --- >>> i2 = indices [2,2] +-- >>> i2 = indices @[2,2] -- >>> pretty $ expand (,) i2 i2 -- [[[[([0,0],[0,0]),([0,0],[0,1])], -- [([0,0],[1,0]),([0,0],[1,1])]], @@ -1160,8 +1164,8 @@ expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ drop r (fromFins i))) ( -- -- This generalises a tensor contraction by allowing the number of contracting diagonals to be other than 2. -- --- >>> let b = array [2,3] [1..6] :: Array Int --- >>> pretty $ contract sum [1,2] (expand (*) b (transpose b)) +-- > let b = array [1..6] :: Array [2,3] Int +-- > pretty $ contract sum [1,2] (expand (*) b (transpose b)) -- [[14,32], -- [32,77]] contract :: @@ -1185,24 +1189,24 @@ contract f xs a = f . diag <$> extractsExcept xs a -- -- matrix multiplication -- --- >>> let b = array [2,3] [1..6] :: Array Int --- >>> pretty $ dot sum (*) b (transpose b) +-- > let b = array [1..6] :: Array [2,3] Int +-- > pretty $ dot sum (*) b (transpose b) -- [[14,32], -- [32,77]] -- -- inner product -- --- >>> let v = array [3] [1..3] :: Array Int +-- >>> let v = array [1..3] :: Array '[3] Int -- >>> pretty $ dot sum (*) v v -- 14 -- -- matrix-vector multiplication -- Note that an Array with shape [3] is neither a row vector nor column vector. -- --- >>> pretty $ dot sum (*) v b +-- > pretty $ dot sum (*) v b -- [9,12,15] -- --- >>> pretty $ dot sum (*) b v +-- > pretty $ dot sum (*) b v -- [14,32] dot :: forall a b c d sa sb s' ss se. @@ -1230,23 +1234,23 @@ dot f g a b = contract f (Proxy :: Proxy '[Rank sa - 1, Rank sa]) (expand g a b) -- -- matrix multiplication -- --- >>> let b = array [2,3] [1..6] :: Array Int --- >>> pretty $ mult b (transpose b) +-- > let b = array [1..6] :: Array [2,3] Int +-- > pretty $ mult b (transpose b) -- [[14,32], -- [32,77]] -- -- inner product -- --- >>> let v = array [3] [1..3] :: Array Int +-- >>> let v = array @'[3] [1..3::Int] -- >>> pretty $ mult v v -- 14 -- -- matrix-vector multiplication -- --- >>> pretty $ mult v b +-- > pretty $ mult v b -- [9,12,15] -- --- >>> pretty $ mult b v +-- > pretty $ mult b v -- [14,32] mult :: forall a sa sb s' ss se. @@ -1270,6 +1274,54 @@ mult :: Array s' a mult = dot sum (*) +-- | Remove single dimensions. +-- +-- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int +-- >>> pretty sq +-- [[[[[1], +-- [2], +-- [3], +-- [4]], +-- [[5], +-- [6], +-- [7], +-- [8]], +-- [[9], +-- [10], +-- [11], +-- [12]]]], +-- [[[[13], +-- [14], +-- [15], +-- [16]], +-- [[17], +-- [18], +-- [19], +-- [20]], +-- [[21], +-- [22], +-- [23], +-- [24]]]]] +-- >>> pretty $ squeeze sq +-- [[[1,2,3,4], +-- [5,6,7,8], +-- [9,10,11,12]], +-- [[13,14,15,16], +-- [17,18,19,20], +-- [21,22,23,24]]] +-- +-- >>> pretty $ squeeze (array [1] :: Array '[1,1] Double) +-- 1.0 +squeeze :: + forall s t a. + (HasShape s, + HasShape t, + t ~ Eval (Squeeze s)) => + Array s a -> + Array t a +squeeze = unsafeModifyShape + + {- -- | Reshape an array (with the same number of elements). -- diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index bf59fe9..69770f8 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -115,12 +115,13 @@ module NumHask.Array.Shape decAt, Zip, Windows, + Fcf.Eval, ) where import Data.List qualified as List import Data.Proxy -import Data.Type.Bool +import Data.Type.Bool hiding (Not) import Data.Type.Equality import GHC.TypeLits qualified as L import Prelude qualified @@ -135,6 +136,9 @@ import GHC.TypeLits (TypeError, ErrorMessage(..)) import Text.Read import Data.Type.Ord import Unsafe.Coerce +import Fcf hiding (type (&&), type (+), type (-), type (++)) +import Fcf qualified +import Control.Monad -- $setup -- >>> :m -Prelude @@ -144,6 +148,7 @@ import Unsafe.Coerce -- >>> :set -XRebindableSyntax -- >>> import NumHask.Prelude -- >>> import NumHask.Array.Shape as S +-- >>> import Fcf (Eval) -- | Get the value of a type level Nat. -- Use with explicit type application, i.e., @valueOf \@42@ @@ -689,16 +694,21 @@ type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where ~ 'True -- | remove 1's from a list +-- +-- >>> squeeze [0,1,2,3] +-- [0,2,3] squeeze :: (Eq a, Multiplicative a) => [a] -> [a] squeeze = filter (/= one) -type family Squeeze (a :: [Nat]) where - Squeeze '[] = '[] - Squeeze a = Filter '[] a 1 +-- | Remove 1's from a list. +-- +-- >>> :k! (Eval (Squeeze [0,1,2,3])) +-- (Eval (Squeeze [0,1,2,3])) :: [Natural] +-- = [0, 2, 3] +data Squeeze :: [a] -> Exp [a] -type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where - Filter r '[] _ = Reverse r - Filter r (x : xs) i = Filter (If (x == i) r (x : r)) xs i +type instance Eval (Squeeze xs) = + Eval (Filter (Not <=< TyEq 1) xs) -- | Reflect a list of Nats class KnownNats (ns :: [Nat]) where @@ -720,14 +730,6 @@ instance KnownNatss '[] where instance (KnownNats n, KnownNatss ns) => KnownNatss (n : ns) where natValss _ = natVals (Proxy @n) : natValss (Proxy @ns) -type family Zip (xs :: [Nat]) (ys :: [Nat]) where - Zip xs ys = ZipGo xs ys '[] - -type family ZipGo (xs :: [Nat]) (ys :: [Nat]) (zs :: [(Nat, Nat)]) where - ZipGo '[] _ zs = zs - ZipGo _ '[] zs = zs - ZipGo (x : xs) (y : ys) zs = zs -- ZipGo xs ys ((x,y):zs) - type family Windows (ws :: [Nat]) (xs :: [Nat]) where Windows ws _ = ws