diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index 0ea2128..00c6b1b 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -317,6 +317,7 @@ instance (HasShape s) => Data.Distributive.Distributive (Array s) where + distribute :: (HasShape s, Functor f) => f (Array s a) -> Array s (f a) distribute = distributeRep {-# INLINE distribute #-} @@ -1176,7 +1177,7 @@ indexesT _ a = unsafeBackpermute (S.insertDims (List.zip (shapeOf @ds) (shapeOf -- | Select an index /except/ along specified dimensions. -- --- >>> let s = indexesExcept (Proxy :: Proxy '[2]) [1,1] a +-- >>> let s = indexesExcept (S.SNats @'[2]) [1,1] a -- >>> :t s -- s :: Array '[4] Int -- @@ -1187,13 +1188,14 @@ indexesExcept :: ( HasShape s, HasShape ds, HasShape s', + KnownNats ds, s' ~ Eval (TakeDims ds s) ) => - Proxy ds -> + SNats ds -> [Int] -> Array s a -> Array s' a -indexesExcept _ i a = unsafeBackpermute (\s -> insertDims (List.zip (shapeOf @ds) s) i) a +indexesExcept ds i a = unsafeBackpermute (\s -> insertDims (List.zip (Prelude.fromIntegral <$> natVals ds) s) i) a -- | Select the first element along the supplied dimensions -- @@ -1321,7 +1323,7 @@ extracts d a = tabulate (\s -> indexes d (fromFins s) a) -- | Extracts /except/ dimensions to an outer layer. -- --- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a +-- >>> let e = extractsExcept (S.SNats @[1,2]) a -- >>> pretty $ shape <$> e -- [[3,4],[3,4]] extractsExcept :: @@ -1330,15 +1332,16 @@ extractsExcept :: HasShape ds, HasShape si, HasShape so, + KnownNats ds, so ~ Eval (DeleteDims ds st), si ~ Eval (TakeDims ds st) ) => - Proxy ds -> + SNats ds -> Array st a -> Array so (Array si a) -extractsExcept d a = tabulate go +extractsExcept ds a = tabulate go where - go s = indexesExcept d (fromFins s) a + go s = indexesExcept ds (fromFins s) a -- | Reduce along specified dimensions, using the supplied fold. -- @@ -1628,10 +1631,11 @@ contract :: HasShape ss, HasShape s', s' ~ Eval (DeleteDims ds s), - ss ~ '[Eval (Minimum (Eval (TakeDims ds s)))] + ss ~ '[Eval (Minimum (Eval (TakeDims ds s)))], + KnownNats ds ) => (Array ss a -> b) -> - Proxy ds -> + SNats ds -> Array s a -> Array s' b contract f xs a = f . diag <$> extractsExcept xs a @@ -1660,7 +1664,7 @@ contract f xs a = f . diag <$> extractsExcept xs a -- > pretty $ dot sum (*) b v -- [14,32] dot :: - forall a b c d sa sb s' ss se. + forall a b c d sa sb s' ss se x. ( HasShape sa, HasShape sb, HasShape (Eval ((++) sa sb)), @@ -1671,15 +1675,17 @@ dot :: KnownNat (Eval (Rank sa)), ss ~ '[Eval (Minimum se)], HasShape ss, - s' ~ Eval (DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb))), - HasShape s' + s' ~ Eval (DeleteDims x (Eval ((++) sa sb))), + HasShape s', + KnownNats x, + x ~ '[Eval (Rank sa) - 1, Eval (Rank sa)] ) => (Array ss c -> d) -> (a -> b -> c) -> Array sa a -> Array sb b -> Array s' d -dot f g a b = contract f (Proxy :: Proxy '[Eval (Rank sa) - 1, Eval (Rank sa)]) (expand g a b) +dot f g a b = contract f (SNats :: SNats x) (expand g a b) -- | Array multiplication. -- @@ -1704,7 +1710,7 @@ dot f g a b = contract f (Proxy :: Proxy '[Eval (Rank sa) - 1, Eval (Rank sa)]) -- > pretty $ mult b v -- [14,32] mult :: - forall a sa sb s' ss se. + forall a sa sb s' ss se x. ( Additive a, Multiplicative a, HasShape sa, @@ -1717,8 +1723,11 @@ mult :: KnownNat (Eval (Rank sa)), ss ~ '[Eval (Minimum se)], HasShape ss, - s' ~ Eval (DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb))), - HasShape s' + s' ~ Eval (DeleteDims x (Eval ((++) sa sb))), + HasShape s', + KnownNats x, + x ~ '[Eval (Rank sa) - 1, Eval (Rank sa)] + ) => Array sa a -> Array sb a -> @@ -2408,7 +2417,8 @@ instance P.Distributive a, Subtractive a, KnownNat m, - HasShape '[m, m] + HasShape '[m, m], + KnownNats '[1,2] ) => Multiplicative (Matrix m m a) where @@ -2423,7 +2433,8 @@ instance Eq a, ExpField a, KnownNat m, - HasShape '[m, m] + HasShape '[m, m], + KnownNats '[1,2] ) => Divisive (Matrix m m a) where @@ -2637,7 +2648,7 @@ uniform g r = do -- [2.1111111111111107,-0.5555555555555555,0.1111111111111111]] -- -- > D.mult (D.inverse a) a == a -inverse :: (Eq a, ExpField a, KnownNat m) => Matrix m m a -> Matrix m m a +inverse :: (Eq a, ExpField a, KnownNat m, KnownNats [1,2]) => Matrix m m a -> Matrix m m a inverse a = mult (invtri (transpose (chol a))) (invtri (chol a)) -- | [Inversion of a Triangular Matrix](https://math.stackexchange.com/questions/1003801/inverse-of-an-invertible-upper-triangular-matrix-of-order-3) @@ -2649,7 +2660,7 @@ inverse a = mult (invtri (transpose (chol a))) (invtri (chol a)) -- [0.0,0.0,1.0]] -- >>> ident == mult t (invtri t) -- True -invtri :: forall a n. (KnownNat n, ExpField a, Eq a) => Array '[n, n] a -> Array '[n, n] a +invtri :: forall a n. (KnownNat n, KnownNats [1,2], ExpField a, Eq a) => Array '[n, n] a -> Array '[n, n] a invtri a = sum (fmap (l ^) (iota @n)) * ti where ti = undiag (fmap recip (diag a)) diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index f2e4e6f..7a31197 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -34,8 +34,12 @@ module NumHask.Array.Shape withSomeNat, valueOf, int, - Shape (..), - HasShape (..), + SNats (..), + pattern SNats, + fromSNats, + KnownNats (..), + natVals, + HasShape, shapeOf, rankOf, sizeOf, @@ -193,6 +197,52 @@ valueOf = Prelude.fromIntegral $ natVal (Proxy :: Proxy n) int :: SNat n -> Int int = Prelude.fromIntegral . fromSNat +-- | Mimics SNat from GHC.TypeNats +newtype SNats (ns :: [Nat]) = UnsafeSNats [Nat] + +instance (KnownNats ns) => Show (SNats ns) + where + show s = "SNats @" <> bool "" "'" (length (natVals s) < 2) <> "[" <> mconcat (List.intersperse ", " (show <$> (natVals s))) <> "]" + +type role SNats nominal + +pattern SNats :: forall ns. () => KnownNats ns => SNats ns +pattern SNats <- (knownNatsInstance -> KnownNatsInstance) + where SNats = natsSing + +fromSNats :: SNats s -> [Nat] +fromSNats (UnsafeSNats s) = s + +-- An internal data type that is only used for defining the SNat pattern +-- synonym. +data KnownNatsInstance (ns :: [Nat]) where + KnownNatsInstance :: KnownNats ns => KnownNatsInstance ns + +-- An internal function that is only used for defining the SNat pattern +-- synonym. +knownNatsInstance :: SNats ns -> KnownNatsInstance ns +knownNatsInstance dims = withKnownNats dims KnownNatsInstance + +-- | Reflect a list of Nats +class KnownNats (ns :: [Nat]) where + natsSing :: SNats ns + +instance KnownNats '[] where + natsSing = UnsafeSNats [] + +instance (KnownNat n, KnownNats s) => KnownNats (n ': s) + where + natsSing = UnsafeSNats (fromSNat (SNat :: SNat n) : fromSNats (SNats :: SNats s)) + +natVals :: forall ns proxy. KnownNats ns => proxy ns -> [Nat] +natVals _ = case natsSing :: SNats ns of + UnsafeSNats xs -> xs + +withKnownNats :: forall ns rep (r :: TYPE rep). + SNats ns -> (KnownNats ns => r) -> r +withKnownNats = withDict @(KnownNats ns) + +{- -- | The Shape type holds a [Nat] at type level and the equivalent [Int] at value level. -- -- >>> toShape @[2,3,4] @@ -211,13 +261,16 @@ instance HasShape '[] where instance (KnownNat n, HasShape s) => HasShape (n : s) where toShape = Shape $ Prelude.fromIntegral (natVal (Proxy :: Proxy n)) : shapeVal (toShape :: Shape s) +-} + +type HasShape = KnownNats --- | Supply the value-level of a 'HasShape' +-- | Supply the value-level of a 'HasShape' as an [Int] -- -- >>> shapeOf @[2,3,4] -- [2,3,4] shapeOf :: forall s. (HasShape s) => [Int] -shapeOf = shapeVal (toShape @s) +shapeOf = Prelude.fromIntegral <$> natVals (Proxy :: Proxy s) {-# INLINE shapeOf #-} -- | The rank of a 'Shape'. @@ -225,7 +278,7 @@ shapeOf = shapeVal (toShape @s) -- >>> rankOf @[2,3,4] -- 3 rankOf :: forall s. (HasShape s) => Int -rankOf = length (shapeVal (toShape @s)) +rankOf = length (shapeOf @s) {-# INLINE rankOf #-} -- | The size of a 'Shape'. @@ -233,7 +286,7 @@ rankOf = length (shapeVal (toShape @s)) -- >>> sizeOf @[2,3,4] -- 24 sizeOf :: forall s. (HasShape s) => Int -sizeOf = product (shapeVal (toShape @s)) +sizeOf = product (shapeOf @s) {-# INLINE sizeOf #-} -- | Fin most often represents a (finite) zer-based index for a single dimension (of a multi-dimensioned hyper-rectangular array).