Skip to content

Commit

Permalink
intro fcf
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyday567 committed Aug 16, 2024
1 parent 8168422 commit c7549e2
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 49 deletions.
1 change: 1 addition & 0 deletions numhask-array.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ library
, adjunctions >=4.0 && <5
, base >=4.14 && <5
, distributive >=0.4 && <0.7
, first-class-families
, numhask >=0.12 && <0.13
, prettyprinter >= 1.7 && <1.8
, random >=1.2 && <1.3
Expand Down
120 changes: 86 additions & 34 deletions src/NumHask/Array/Fixed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ module NumHask.Array.Fixed
dot,
mult,

-- * Shape manipulations
squeeze,

{-
reshape,
Expand Down Expand Up @@ -162,7 +165,7 @@ import Data.Vector qualified as V
-- import GHC.TypeLits
import GHC.TypeNats
import NumHask.Array.Dynamic qualified as D
import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton)
import NumHask.Array.Shape hiding (rank, size, asScalar, asSingleton, squeeze)
import NumHask.Array.Shape qualified as S
import NumHask.Prelude as P hiding (Min, take, diff, zipWith, empty, sequence, toList, length)
import Prettyprinter hiding (dot)
Expand Down Expand Up @@ -225,6 +228,8 @@ import Data.List qualified as List

-- | A multidimensional array with a type-level shape
--
-- >>> array @[2,3,4] [1..24::Int]
-- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
-- >>> array [1..24] :: Array '[2,3,4] Int
-- [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
-- >>> pretty (array [1..24] :: Array '[2,3,4] Int)
Expand All @@ -240,7 +245,8 @@ import Data.List qualified as List
--
-- In many spots, [TypeApplication](https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/type_applications.html) can be cleaner.
--
-- >>> array @[2,3] @[Int] [1..6]
-- FIXME:
-- > array @[2,3] @[Int] [1..6]
-- [1,2,3,4,5,6]
--
-- >>> index a (S.UnsafeFins [1,2,3])
Expand Down Expand Up @@ -398,7 +404,7 @@ safeArray v =
--
-- >>> array [0..22] :: Array [2,3,4] Int
-- *** Exception: NumHaskException {errorMessage = "ShapeMismatch"}
array :: (HasShape s, FromVector t a) => t -> Array s a
array :: forall s a t. (HasShape s, FromVector t a) => t -> Array s a
array v =
fromMaybe (throw (NumHaskException "ShapeMismatch")) (safeArray v)

Expand Down Expand Up @@ -812,7 +818,7 @@ indexes _ xs a = unsafeBackpermute (S.insertDims (shapeOf @ds) xs) a
-- s :: Array '[4] Int
--
-- >>> pretty $ s
-- [17,18,19,20]
-- [16,17,18,19]
indexesExcept ::
forall ds s s' a.
( HasShape s,
Expand Down Expand Up @@ -854,7 +860,7 @@ lasts ds a = indexes ds lastxs a
-- | Select the tail elements along the supplied dimensions
--
-- FIXME: Get new shape into the type level.
-- >>> pretty $ tails (Proxy :: Proxy [0,2]) (Proxy :: Proxy [1,2,3]) a
-- > pretty $ tails (Proxy :: Proxy [0,2]) (Proxy :: Proxy [1,2,3]) a
-- [[[13,14,15],
-- [17,18,19],
-- [21,22,23]]]
Expand Down Expand Up @@ -903,7 +909,7 @@ slices _ o _ a = unsafeBackpermute (S.modifyDims (shapeOf @ds) (fmap (+) o)) a
--
-- > a == (fromScalar <$> extracts [0..rank a] a)
--
-- >>> pretty $ shape <$> extracts (Proxy :: Proxy [0]) a
-- >>> pretty $ shape <$> extracts (Proxy :: Proxy '[0]) a
-- [[3,4],[3,4]]
extracts ::
forall ds st si so a.
Expand All @@ -921,8 +927,8 @@ extracts d a = tabulate (\s -> indexes d (fromFins s) a)

-- | Extracts /except/ dimensions to an outer layer.
--
-- >>> let e = extractsExcept [1,2] a
-- >>> pretty $ shape <$> extracts [0] a
-- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a
-- >>> pretty $ shape <$> e
-- [[3,4],[3,4]]
extractsExcept ::
forall ds st si so a.
Expand All @@ -942,7 +948,7 @@ extractsExcept d a = tabulate go

-- | Reduce along specified dimensions, using the supplied fold.
--
-- >>> pretty $ reduces (Proxy :: Proxy [0]) sum a
-- >>> pretty $ reduces (Proxy :: Proxy '[0]) sum a
-- [66,210]
-- >>> pretty $ reduces (Proxy :: Proxy [0,2]) sum a
-- [[12,15,18,21],
Expand Down Expand Up @@ -983,8 +989,8 @@ traverses ds f a = joins ds <$> traverse (traverse f) (extracts ds a)

-- | Join inner and outer dimension layers by supplied dimensions. No checks on shape.
--
-- >>> let e = extracts [1,0] a
-- >>> let j = joins [1,0] e
-- >>> let e = extracts (Proxy :: Proxy [1,0]) a
-- >>> let j = joins (Proxy :: Proxy [1,0]) e
-- >>> a == j
-- True
joins ::
Expand All @@ -1003,14 +1009,15 @@ joins _ a = tabulate go

-- | Join inner and outer dimension layers in outer dimension order.
--
-- >>> a == join (extracts [0,1] a)
-- FIXME:
-- > a == join (extracts (Proxy :: Proxy [0,1]) a)
-- True
join ::
forall a si so st.
(HasShape st,
HasShape si,
HasShape so,
PrependDims so si ~ st) =>
PrependDims si so ~ st) =>
Array so (Array si a) ->
Array st a
join a = tabulate go
Expand All @@ -1020,7 +1027,7 @@ join a = tabulate go

-- | Maps a function along specified dimensions.
--
-- >>> :t maps (transpose) (Proxy :: Proxy '[1]) a
-- > :t maps (transpose) (Proxy :: Proxy '[1]) a
-- maps (transpose) (Proxy :: Proxy '[1]) a :: Array [4, 3, 2] Int
maps ::
forall ds st st' si si' so a b.
Expand All @@ -1041,13 +1048,10 @@ maps ::
Array st' b
maps f d a = joins d (fmapRep f (extracts d a))

-- | Filters along specified dimensions (which are flattened).
-- | Filters along specified dimensions (which are flattened as a dynamic array).
--
-- >>> pretty $ filters (Proxy :: Proxy [0,1]) (any ((==0) . (`mod` 7))) a
-- [[0,1,2,3],
-- [4,5,6,7],
-- [12,13,14,15],
-- [20,21,22,23]]
-- [[0,1,2,3],[4,5,6,7],[12,13,14,15],[20,21,22,23]]
filters ::
forall ds si so a.
( HasShape ds,
Expand All @@ -1064,7 +1068,7 @@ filters ds p a = D.asArray $ V.filter p $ asVector (extracts ds a)

-- | Zips two arrays with a function along specified dimensions.
--
-- >>> pretty $ zips (Proxy :: Proxy [0,1]) (zipWith (,)) a (reverses [0] a)
-- > pretty $ zips (Proxy :: Proxy [0,1]) (zipWith (,)) a (reverses [0] a)
-- [[[(0,12),(1,13),(2,14),(3,15)],
-- [(4,16),(5,17),(6,18),(7,19)],
-- [(8,20),(9,21),(10,22),(11,23)]],
Expand Down Expand Up @@ -1099,15 +1103,15 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b))
--
-- ... the tensor product can be extended to other categories of mathematical objects in addition to vector spaces, such as to matrices, tensors, algebras, topological vector spaces, and modules. In each such case the tensor product is characterized by a similar universal property: it is the freest bilinear operation. The general concept of a "tensor product" is captured by monoidal categories; that is, the class of all things that have a tensor product is a monoidal category.
--
-- >>> x = array [3] [1,2,3]
-- >>> x = array [1,2,3] :: Array '[3] Int
-- >>> pretty $ expand (*) x x
-- [[1,2,3],
-- [2,4,6],
-- [3,6,9]]
--
-- Alternatively, expand can be understood as representing the permutation of element pairs of two arrays, so like the Applicative List instance.
--
-- >>> i2 = indices [2,2]
-- >>> i2 = indices @[2,2]
-- >>> pretty $ expand (,) i2 i2
-- [[[[([0,0],[0,0]),([0,0],[0,1])],
-- [([0,0],[1,0]),([0,0],[1,1])]],
Expand Down Expand Up @@ -1160,8 +1164,8 @@ expandr f a b = tabulate (\i -> f (index a (UnsafeFins $ drop r (fromFins i))) (
--
-- This generalises a tensor contraction by allowing the number of contracting diagonals to be other than 2.
--
-- >>> let b = array [2,3] [1..6] :: Array Int
-- >>> pretty $ contract sum [1,2] (expand (*) b (transpose b))
-- > let b = array [1..6] :: Array [2,3] Int
-- > pretty $ contract sum [1,2] (expand (*) b (transpose b))
-- [[14,32],
-- [32,77]]
contract ::
Expand All @@ -1185,24 +1189,24 @@ contract f xs a = f . diag <$> extractsExcept xs a
--
-- matrix multiplication
--
-- >>> let b = array [2,3] [1..6] :: Array Int
-- >>> pretty $ dot sum (*) b (transpose b)
-- > let b = array [1..6] :: Array [2,3] Int
-- > pretty $ dot sum (*) b (transpose b)
-- [[14,32],
-- [32,77]]
--
-- inner product
--
-- >>> let v = array [3] [1..3] :: Array Int
-- >>> let v = array [1..3] :: Array '[3] Int
-- >>> pretty $ dot sum (*) v v
-- 14
--
-- matrix-vector multiplication
-- Note that an Array with shape [3] is neither a row vector nor column vector.
--
-- >>> pretty $ dot sum (*) v b
-- > pretty $ dot sum (*) v b
-- [9,12,15]
--
-- >>> pretty $ dot sum (*) b v
-- > pretty $ dot sum (*) b v
-- [14,32]
dot ::
forall a b c d sa sb s' ss se.
Expand Down Expand Up @@ -1230,23 +1234,23 @@ dot f g a b = contract f (Proxy :: Proxy '[Rank sa - 1, Rank sa]) (expand g a b)
--
-- matrix multiplication
--
-- >>> let b = array [2,3] [1..6] :: Array Int
-- >>> pretty $ mult b (transpose b)
-- > let b = array [1..6] :: Array [2,3] Int
-- > pretty $ mult b (transpose b)
-- [[14,32],
-- [32,77]]
--
-- inner product
--
-- >>> let v = array [3] [1..3] :: Array Int
-- >>> let v = array @'[3] [1..3::Int]
-- >>> pretty $ mult v v
-- 14
--
-- matrix-vector multiplication
--
-- >>> pretty $ mult v b
-- > pretty $ mult v b
-- [9,12,15]
--
-- >>> pretty $ mult b v
-- > pretty $ mult b v
-- [14,32]
mult ::
forall a sa sb s' ss se.
Expand All @@ -1270,6 +1274,54 @@ mult ::
Array s' a
mult = dot sum (*)

-- | Remove single dimensions.
--
-- >>> let sq = array [1..24] :: Array '[2,1,3,4,1] Int
-- >>> pretty sq
-- [[[[[1],
-- [2],
-- [3],
-- [4]],
-- [[5],
-- [6],
-- [7],
-- [8]],
-- [[9],
-- [10],
-- [11],
-- [12]]]],
-- [[[[13],
-- [14],
-- [15],
-- [16]],
-- [[17],
-- [18],
-- [19],
-- [20]],
-- [[21],
-- [22],
-- [23],
-- [24]]]]]
-- >>> pretty $ squeeze sq
-- [[[1,2,3,4],
-- [5,6,7,8],
-- [9,10,11,12]],
-- [[13,14,15,16],
-- [17,18,19,20],
-- [21,22,23,24]]]
--
-- >>> pretty $ squeeze (array [1] :: Array '[1,1] Double)
-- 1.0
squeeze ::
forall s t a.
(HasShape s,
HasShape t,
t ~ Eval (Squeeze s)) =>
Array s a ->
Array t a
squeeze = unsafeModifyShape


{-
-- | Reshape an array (with the same number of elements).
--
Expand Down
32 changes: 17 additions & 15 deletions src/NumHask/Array/Shape.hs
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,13 @@ module NumHask.Array.Shape
decAt,
Zip,
Windows,
Fcf.Eval,
)
where

import Data.List qualified as List
import Data.Proxy
import Data.Type.Bool
import Data.Type.Bool hiding (Not)
import Data.Type.Equality
import GHC.TypeLits qualified as L
import Prelude qualified
Expand All @@ -135,6 +136,9 @@ import GHC.TypeLits (TypeError, ErrorMessage(..))
import Text.Read
import Data.Type.Ord
import Unsafe.Coerce
import Fcf hiding (type (&&), type (+), type (-), type (++))
import Fcf qualified
import Control.Monad

-- $setup
-- >>> :m -Prelude
Expand All @@ -144,6 +148,7 @@ import Unsafe.Coerce
-- >>> :set -XRebindableSyntax
-- >>> import NumHask.Prelude
-- >>> import NumHask.Array.Shape as S
-- >>> import Fcf (Eval)

-- | Get the value of a type level Nat.
-- Use with explicit type application, i.e., @valueOf \@42@
Expand Down Expand Up @@ -689,16 +694,21 @@ type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
~ 'True

-- | remove 1's from a list
--
-- >>> squeeze [0,1,2,3]
-- [0,2,3]
squeeze :: (Eq a, Multiplicative a) => [a] -> [a]
squeeze = filter (/= one)

type family Squeeze (a :: [Nat]) where
Squeeze '[] = '[]
Squeeze a = Filter '[] a 1
-- | Remove 1's from a list.
--
-- >>> :k! (Eval (Squeeze [0,1,2,3]))
-- (Eval (Squeeze [0,1,2,3])) :: [Natural]
-- = [0, 2, 3]
data Squeeze :: [a] -> Exp [a]

type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
Filter r '[] _ = Reverse r
Filter r (x : xs) i = Filter (If (x == i) r (x : r)) xs i
type instance Eval (Squeeze xs) =
Eval (Filter (Not <=< TyEq 1) xs)

-- | Reflect a list of Nats
class KnownNats (ns :: [Nat]) where
Expand All @@ -720,14 +730,6 @@ instance KnownNatss '[] where
instance (KnownNats n, KnownNatss ns) => KnownNatss (n : ns) where
natValss _ = natVals (Proxy @n) : natValss (Proxy @ns)

type family Zip (xs :: [Nat]) (ys :: [Nat]) where
Zip xs ys = ZipGo xs ys '[]

type family ZipGo (xs :: [Nat]) (ys :: [Nat]) (zs :: [(Nat, Nat)]) where
ZipGo '[] _ zs = zs
ZipGo _ '[] zs = zs
ZipGo (x : xs) (y : ys) zs = zs -- ZipGo xs ys ((x,y):zs)

type family Windows (ws :: [Nat]) (xs :: [Nat]) where
Windows ws _ = ws

Expand Down

0 comments on commit c7549e2

Please sign in to comment.