Skip to content

Commit

Permalink
Work on ModW32, ModW16, ModW8
Browse files Browse the repository at this point in the history
  • Loading branch information
DaveBarton committed Jul 5, 2024
1 parent f8c3c87 commit e8fc942
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 31 deletions.
8 changes: 4 additions & 4 deletions calculi.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ source-repository head

common deps
default-language: GHC2021
default-extensions: DuplicateRecordFields LambdaCase MonoLocalBinds NegativeLiterals
NoFieldSelectors OverloadedRecordDot OverloadedStrings PatternSynonyms
RecordWildCards
default-extensions: DerivingStrategies DuplicateRecordFields LambdaCase MonoLocalBinds
NegativeLiterals NoFieldSelectors OverloadedRecordDot OverloadedStrings
PatternSynonyms RecordWildCards
other-extensions: CPP DataKinds FunctionalDependencies Strict ViewPatterns
ghc-options: -Wall -Wcompat -feager-blackholing -threaded
if impl(ghc >= 9.8.1)
Expand All @@ -38,6 +38,7 @@ common deps
extra >= 1.6 && < 1.8,
fmt >= 0.4 && < 0.7,
megaparsec < 9.7,
mod >= 0.1.2 && < 0.3,
parser-combinators < 1.4,
safe < 0.4,
strict >= 0.4 && < 0.6,
Expand Down Expand Up @@ -69,7 +70,6 @@ library
ghc-trace-events < 0.2,
-- hashable, unordered-containers,
-- inspection-testing >= 0.3 && < 0.6,
mod >= 0.1.2 && < 0.3,
-- poly >= 0.5.1, deepseq, finite-typelits, vector-sized,
primitive >= 0.8 && < 0.10,
random >= 1.2.1 && < 1.3,
Expand Down
54 changes: 52 additions & 2 deletions src/Math/Algebra/Commutative/Field/ZModPW.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE Strict #-}
{-# LANGUAGE DataKinds, Strict #-}

{- | The field of integers mod @p@, for a prime @p@ that fits in a 'Word'. -}

Expand All @@ -10,8 +10,11 @@ module Math.Algebra.Commutative.Field.ZModPW (
import Math.Algebra.General.Algebra

import Data.Mod.Word (Mod, unMod)
import Data.Primitive.Types (Prim)
import Data.Proxy (Proxy(Proxy))
import GHC.TypeNats (KnownNat, SomeNat(SomeNat), natVal, someNatVal)
import Data.Word (Word8, Word16, Word32)
import GHC.TypeNats (KnownNat, Nat, SomeNat(SomeNat), natVal, someNatVal)
import Unsafe.Coerce (unsafeCoerce)


zzModPW :: forall p. KnownNat p => (Field (Mod p), Mod p -> Integer)
Expand All @@ -23,3 +26,50 @@ zzModPW = (field numAG (*) 1 fromInteger recip, balRep)
balRep x =
let u = unMod x
in toInteger (fromIntegral (if u > maxBalRep then u - p else u) :: Int)


-- types to save space from a (Mod m), e.g. in a PrimArray:

toWord :: Mod m -> Word
-- @@ move to Data.Mod.Word
toWord = unsafeCoerce

unsafeFromWord :: Word -> Mod m
-- @@ move to Data.Mod.Word
unsafeFromWord = unsafeCoerce

modW32 :: Word32 -> Mod m
modW32 = unsafeFromWord . fromIntegral

modW16 :: Word16 -> Mod m
modW16 = unsafeFromWord . fromIntegral

modW8 :: Word8 -> Mod m
modW8 = unsafeFromWord . fromIntegral

newtype ModWord32 (m :: Nat) = ModW32 { unModW32 :: Word32 }
deriving newtype (Eq, Show, Prim)

{- @@@@ need instance Num, better doc, ModWord16, ModWord8
instance KnownNat m => Num (ModWord32 m) where
ModW32 x + ModW32 y = ModW32 $ modW32 @m x + modW32 @m y
{-# INLINE (+) #-}
ModW32 x - ModW32 y = ModW32 $ modW32 @m x - modW32 @m y
{-# INLINE (-) #-}
ModW32 x * ModW32 y = ModW32 $ modW32 @m x * modW32 @m y
{-# INLINE (*) #-}
-- @@@@:
negate (ModW32 x) = Mod $ negateMod (natVal mx) x
{-# INLINE negate #-}
abs = id
{-# INLINE abs #-}
signum = const x
where
x = if natVal x > 1 then Mod 1 else Mod 0
{-# INLINE signum #-}
fromInteger x = mx
where
mx = Mod $ fromIntegerMod (natVal mx) x
{-# INLINE fromInteger #-}
-}
3 changes: 2 additions & 1 deletion src/Math/Algebra/General/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
a type class, because a single type may admit more than one structure as a given type of
algebra. (For example, consider quotient algebras such as @ℤ/pℤ@ or @R[X, Y]/(f, g)@ for
various dynamically computed @p@, @f@, and @g@.) Also, treating algebras as first-class
values allows us to construct them at arbitrary times in arbitrary ways.
values allows us to construct them at arbitrary times in arbitrary ways. On the other hand,
type classes are necessary if inlining is desired, e.g. for primitive types.
Note that a set of constructive (e.g. Haskell) values is an attempt to represent a set of
abstract (mathematical) values. The abstract values have an implicit notion of (abstract)
Expand Down
22 changes: 15 additions & 7 deletions src/Math/Algebra/Linear/SparseVector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ module Math.Algebra.Linear.SparseVector (
-- * Addition
unionWith, plusU, mkAG,
-- * Multiplication
dotWith, timesNzdC, timesC, monicizeU,
dotWith, timesNzdC, timesNzdCU, timesC, monicizeUnit,
-- * I/O
showPrec
) where
Expand Down Expand Up @@ -620,6 +620,7 @@ aPlusU bs0 nzs0 bs1 nzs1 = runST $ do
go bsAll bsAll 0 0 0
{-# SPECIALIZE aPlusU :: KnownNat m => Word64 -> PrimArray (Mod m) ->
Word64 -> PrimArray (Mod m) -> Word64 :!: PrimArray (Mod m) #-}
{-# INLINABLE aPlusU #-}

plusU :: (Eq c, Num c, Prim c) => Op2 (VectorU c)
{- ^ \(m + n\) steps, or more precisely for sparse vectors, \(k + b t\) where \(k\) and \(t\)
Expand All @@ -644,6 +645,7 @@ plusU = go
~v0 = go (C.index nzts 0) v
goGT (SVE {}) _ = undefined
{-# SPECIALIZE plusU :: KnownNat m => Op2 (VectorU (Mod m)) #-}
{-# INLINABLE plusU #-}

mkAG :: (C.Contiguous arr, C.Element arr c) =>
AbelianGroup c -> AbelianGroup (VectorA arr c)
Expand Down Expand Up @@ -700,20 +702,26 @@ timesNzdC :: (C.Contiguous arr, C.Element arr c) => Ring c -> c -> Op1 (Ve
timesNzdC (Ring { times }) c = mapNzFC (`times` c)
{-# INLINE timesNzdC #-}

timesNzdCU :: (Num c, Prim c) => c -> Op1 (VectorU c)
{- ^ 'timesNzdC' over an unboxed type, for speed. -}
timesNzdCU c (SVE bs nzs) = SVE bs (C.map' (* c) nzs)
timesNzdCU c (SVV bs iW2 nzts) = SVV bs iW2 (C.map' (timesNzdCU c) nzts)
{-# INLINABLE timesNzdCU #-}

timesC :: (C.Contiguous arr, C.Element arr c) => Ring c -> c -> Op1 (VectorA arr c)
{- ^ If the @c@ is not a right zero divisor, then 'timesNzdC' is faster. Usually \(m\) steps, or
up to \((d - log_{64} m) m\) for a very sparse vector. -}
timesC cR@(Ring { times }) c = mapC cR.isZero (`times` c)
{-# INLINE timesC #-}

monicizeU :: (C.Contiguous arr, C.Element arr c) => Ring c -> Int -> Op1 (VectorA arr c)
{- ^ @(monicizeU cR i v)@ requires that the @i@'th coefficient of @v@ is a unit. Usually \(m\)
steps, or up to \((d - log_{64} m) m\) for a very sparse vector, but checks first whether
@v@ is already monic. -}
monicizeU cR@(Ring { times }) i v =
monicizeUnit :: (C.Contiguous arr, C.Element arr c) => Ring c -> Int -> Op1 (VectorA arr c)
{- ^ @(monicizeUnit cR i v)@ requires that the @i@'th coefficient of @v@ is a unit. Usually
\(m\) steps, or up to \((d - log_{64} m) m\) for a very sparse vector, but checks first
whether @v@ is already monic. -}
monicizeUnit cR@(Ring { times }) i v =
let c = index cR.zero v i -- check for c = 1 for speed
in if rIsOne cR c then v else mapNzFC (`times` rInv cR c) v
{-# SPECIALIZE monicizeU :: Ring c -> Int -> Op1 (Vector c) #-}
{-# SPECIALIZE monicizeUnit :: Ring c -> Int -> Op1 (Vector c) #-}

-- * I/O

Expand Down
24 changes: 12 additions & 12 deletions test/Math/Algebra/Linear/TestSparseVector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ testOps cAG sumRange iTA cTA = TestOps tSP tCheck gen vAG.eq
gen = sumL' vAG <$> Gen.list sumRange (liftA2 iCToV iTA.gen cTA.gen)

type V = SV.Vector Integer -- the main type for testing SparseVector.hs
type U = Int -- V maps almost injectively to U
type Y = Int -- V maps almost injectively to Y
type VL = [Int :!: Integer] -- only DistinctAscNzs; V->VL->V == id, so VL->V is a
-- surjection
-- type IM = IM.IntMap -- only nonzero terms; V->IM->V == id
Expand All @@ -68,15 +68,15 @@ tests = testGroup "SparseVector" testsL
cTA = zzTestOps { gen = zzExpGen 200 }
vTA = testOps cAG (Range.linear 0 20) iTA cTA
vAG = SV.mkAG cAG
vToU :: V -> U
iCToU i c = (3 * i `rem` 101 + 5) * fromIntegral c
vToU = SV.foldBIMap' (+) 0 iCToU
vToY :: V -> Y
iCToY i c = (3 * i `rem` 101 + 5) * fromIntegral c
vToY = SV.foldBIMap' (+) 0 iCToY

vToIM = IM.fromDistinctAscList . map toLazy . SV.toDistinctAscNzs
imNzsToV = SV.fromDistinctAscNzs . map toStrict . IM.toAscList

testViaU :: V -> U -> TestM () -- tAnnotate v, and check it maps to u
testViaU = tImageEq vTA (===) vToU
testViaY :: V -> Y -> TestM () -- tAnnotate v, and check it maps to u
testViaY = tImageEq vTA (===) vToY
testViaL :: TestRel b -> (V -> b) -> (VL -> b) -> TestM () -- test the (V -> b)
testViaL bTestEq f okF = sameFun1TR vTA bTestEq f (okF . SV.toDistinctAscNzs)
testEqToVL :: V -> VL -> TestM ()
Expand All @@ -85,9 +85,9 @@ tests = testGroup "SparseVector" testsL
fromICTest = singleTest "fromIC" $ do -- test fromPIC, fromIMaybeC, fromNzIC
i <- genVis iTA
c <- genVis cTA
when (c /= 0) $ testViaU (SV.fromNzIC i c) (iCToU i c)
testViaU (SV.fromPIC cAG.isZero i c) (iCToU i c)
testViaU (SV.fromIMaybeC i (sJustIf (c /= 0) c)) (iCToU i c)
when (c /= 0) $ testViaY (SV.fromNzIC i c) (iCToY i c)
testViaY (SV.fromPIC cAG.isZero i c) (iCToY i c)
testViaY (SV.fromIMaybeC i (sJustIf (c /= 0) c)) (iCToY i c)
distinctAscNzsTest = singleTest "distinctAscNzs" $ do
-- test toDistinctAscNzs, fromDistinctAscNzs
v <- genVis vTA
Expand Down Expand Up @@ -148,9 +148,9 @@ tests = testGroup "SparseVector" testsL
-- :@@@

testsL = [
singleTest "not eq, vToU" $ almostInjectiveTM vTA (==) vToU,
-- testViaU is now valid
singleTest "plus" $ homomTM vTA vAG.plus (===) (+) vToU,
singleTest "not eq, vToY" $ almostInjectiveTM vTA (==) vToY,
-- testViaY is now valid
singleTest "plus" $ homomTM vTA vAG.plus (===) (+) vToY,
abelianGroupTests vTA (IsNontrivial True) vAG,
-- vAG has now been tested (by the above lines), so vTA.gen and vTA.eq have been also
testOnce "zero" $ vTA.tEq SV.zero vAG.zero, fromICTest, distinctAscNzsTest,
Expand Down
18 changes: 13 additions & 5 deletions timings/bench.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Control.Monad ((<$!>))
import Data.Bits ((.|.), complement, finiteBitSize, shift, unsafeShiftL, unsafeShiftR)
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.List (transpose)
import Data.Mod.Word (Mod)
-- import Data.Poly.Multi (toMultiPoly)
import Data.Strict.Classes (toStrict)
import qualified Data.Strict.Tuple as S
Expand Down Expand Up @@ -106,22 +107,29 @@ showNtSparse, showNtDense :: Int -> String
op2SF :: (c -> String) -> String -> (c -> String) -> c -> String
op2SF xSF opS ySF c = xSF c <> opS <> ySF c

type ModP = Mod 2_000_003
type SV = SV.VectorU ModP

benchesSV :: [Benchmark]
benchesSV = picBenches <> plusBenches
benchesSV = {- picBenches <> plusBenches <> -} scaleBenches
where
vAG = SV.mkAG intRing.ag :: AbelianGroup (SV.Vector Int)
iCToV = SV.fromPIC intRing.ag.isZero
vAG = (SV.mkAG numAG :: AbelianGroup SV) { plus = SV.plusU }
iCToV = SV.fromPIC (== 0)
makeSV g (m, n) = sumL' vAG $ take m -- sum m terms in dim n; 11 should not divide n
[iCToV (r `rem` n) (r `rem` 11 - 5) :: SV.Vector Int
[iCToV (r `rem` n) (fromIntegral (r `rem` 11 - 5)) :: SV
| r <- randomsBy (uniformR (0, 11 * n - 1)) g]
(g0, g1) = split (mkStdGen 37)
iToVs i = sum [SV.index 0 (SV.fromNzIC i n :: SV.Vector Int) i | n <- [1 .. 1000]]
iToVs i = sum [SV.index 0 (SV.fromNzIC i n :: SV) i | n <- [1 .. 1000]]
picBenches = benchWhnf iToVs (("index iCToV x1000 " <>) . show) id <$>
[10 ^ n | n <- [0 :: Int, 2, 4, 6, 12, 18]]
plusBenches = bench2Whnf vAG.plus (("Add " <>) . show) (makeSV g0) (makeSV g1) <$>
[(20, 1000), (300, 1000), (700, 1000),
(1000, 100_000), (10_000, 100_000), (30_000, 100_000),
(1000, 2 ^ (finiteBitSize (0 :: Int) - 5))]
scaleBenches = bench2Whnf SV.timesNzdCU (("Scale " <>) . show) (const 23) (makeSV g1) <$>
[(20, 1000), (300, 1000), (700, 1000),
(1000, 100_000), (10_000, 100_000), (30_000, 100_000),
(1000, 2 ^ (finiteBitSize (0 :: Int) - 5))]

benchesUPoly :: [Benchmark]
benchesUPoly = concat [plusBenches, timesBenches, divBenches]
Expand Down

0 comments on commit e8fc942

Please sign in to comment.