From fe8063c1c3e06db5e194efe7f8403b8ceddfe930 Mon Sep 17 00:00:00 2001 From: Dave Barton Date: Mon, 19 Aug 2024 13:07:05 -0700 Subject: [PATCH] Add SV.Permute, a semi-sparse permutation Also more work on VectorA & MatrixA, and use samsort. --- calculi.cabal | 11 +- src/Control/Parallel/Cooperative.hs | 31 +- src/Math/Algebra/Commutative/BinPoly.hs | 3 +- src/Math/Algebra/Commutative/Field/ZModPW.hs | 10 +- src/Math/Algebra/Commutative/GroebnerBasis.hs | 6 +- src/Math/Algebra/General/Algebra.hs | 25 +- src/Math/Algebra/General/SparseSum.hs | 3 +- .../Algebra/Linear/SparseColumnsMatrix.hs | 136 +++-- src/Math/Algebra/Linear/SparseVector.hs | 573 +++++++++++++----- test/Math/Algebra/General/TestAlgebra.hs | 14 +- test/Math/Algebra/Linear/TestSparseVector.hs | 129 ++-- timings/bench.hs | 7 - timings/time-gb.hs | 4 +- 13 files changed, 665 insertions(+), 287 deletions(-) diff --git a/calculi.cabal b/calculi.cabal index 305be89..c67a883 100644 --- a/calculi.cabal +++ b/calculi.cabal @@ -35,12 +35,14 @@ common deps build-depends: base >= 4.17 && < 5.0, containers >= 0.6 && < 0.8, + contiguous >= 0.6.1.1 && < 0.7, extra >= 1.6 && < 1.8, fmt >= 0.4 && < 0.7, + ghc-trace-events < 0.2, megaparsec < 9.7, mod >= 0.1.2 && < 0.3, parser-combinators < 1.4, - safe < 0.4, + -- safe < 0.4, strict >= 0.4 && < 0.6, text < 2.2 @@ -66,14 +68,13 @@ library build-depends: async < 2.3, - contiguous >= 0.6.1.1 && < 0.7, - ghc-trace-events < 0.2, -- hashable, unordered-containers, -- inspection-testing >= 0.3 && < 0.6, -- poly >= 0.5.1, deepseq, finite-typelits, vector-sized, primitive >= 0.8 && < 0.10, random >= 1.2.1 && < 1.3, rrb-vector >= 0.2.1 && < 0.3, + samsort < 0.2, stm >= 2.4 && < 2.6, strict-list < 0.2, time >= 1.9.3 && < 2, @@ -111,7 +112,7 @@ test-suite calculi-test Math.Algebra.Commutative.TestEPoly Math.Algebra.Commutative.TestBinPoly - build-depends: calculi, hedgehog == 1.*, mod, rrb-vector, strict-list, tasty < 1.6, + build-depends: calculi, hedgehog == 1.*, rrb-vector, strict-list, tasty < 1.6, tasty-hedgehog < 1.5 hs-source-dirs: test @@ -135,7 +136,7 @@ benchmark bench import: deps main-is: bench.hs type: exitcode-stdio-1.0 - build-depends: calculi, deepseq < 2, random >= 1.2 && < 1.3, tasty, tasty-bench < 0.4 + build-depends: calculi, deepseq < 2, random >= 1.2 && < 1.3, tasty, tasty-bench < 0.5 -- , poly, vector, vector-sized hs-source-dirs: timings ghc-options: -rtsopts -with-rtsopts=-T diff --git a/src/Control/Parallel/Cooperative.hs b/src/Control/Parallel/Cooperative.hs index 2e5481d..13d4d99 100644 --- a/src/Control/Parallel/Cooperative.hs +++ b/src/Control/Parallel/Cooperative.hs @@ -22,7 +22,8 @@ If you can guarantee your threads never block during a task, you can allocate them one-per-core (really one-per-capability) using 'parNonBlocking', for a small performance speedup. Otherwise, it's better to have some extra threads, and let them migrate between - cores, using 'parThreads'. + cores, using 'parThreads'. In either case, we don't create and destroy threads as often as + sparks do, which seems to speed things up. This module uses "Control.Concurrent.Async" to ensure that uncaught exceptions are propagated to parent threads, and orphaned threads are always killed. @@ -44,7 +45,7 @@ module Control.Parallel.Cooperative ( vecMapParChunk, rrbMapParChunk, -- * Folding and sorting - foldBalanced, foldBalancedPar, sortByPar, + foldBalanced, foldBalancedPar, sortLBy, -- sortByPar, -- * Utilities seqSpine, seqElts, inc1TVar, maybeStateTVar, popTVar, @@ -54,14 +55,18 @@ module Control.Parallel.Cooperative ( import Control.Monad (when) import Control.Monad.Extra (ifM, unlessM, whileM) +import Control.Monad.ST (runST) import Data.Bits (Bits, FiniteBits, (.&.), bit, countLeadingZeros, finiteBitSize, xor) import Data.Foldable (toList) import qualified Data.IntMap.Strict as IMS -import Data.List (sortBy, uncons, unfoldr) -import Data.List.Extra (chunksOf, mergeBy) +import Data.List (uncons, unfoldr) +import Data.List.Extra (chunksOf) -- import Data.Maybe (isJust) import Data.Maybe (isNothing) +import Data.Primitive.Array (pattern MutableArray) +import qualified Data.Primitive.Contiguous as C import qualified Data.RRBVector as RRB +import Data.SamSort (sortArrayBy) import qualified Data.Sequence as Seq import qualified Data.Vector as V import GHC.Stack (HasCallStack) @@ -315,9 +320,25 @@ foldBalancedPar f as = if null (drop 3 as) then foldBalanced f as else buddy k = (2 * lowNzBit k) `xor` k top = bit (fbTruncLog2 n) +sortLBy :: (a -> a -> Ordering) -> [a] -> [a] +{- ^ Like 'Data.List.sortBy', but faster, though not lazy (it always sorts the entire list). + 'sortLBy' currently uses [samsort](https://hackage.haskell.org/package/samsort), so it's + stable and adaptive. Also like 'Data.SamSort.sortArrayBy', 'sortLBy' will inline to get the + best performance out of statically known comparison functions. To avoid code duplication, + create a wrapping definition and reuse it as necessary. -} +sortLBy cmp xs = runST $ do + ma@(MutableArray ma') <- C.fromListMutable xs + n <- C.sizeMut ma + sortArrayBy cmp ma' 0 n + C.toListMutable ma +{-# INLINE sortLBy #-} + +{- @@@ faster to just use sortLBy (samsort) sortByPar :: Int -> (a -> a -> Ordering) -> [a] -> [a] {- ^ Strict stable sort by sorting chunks in parallel. The chunk size must be positive, and 100 appears to be a good value. The spine of the result is forced. -} sortByPar chunkSize cmp as = if null as then [] else forkJoinPar (chunksOf chunkSize) (foldBalancedPar (\es -> seqSpine . mergeBy cmp es)) - (seqSpine . sortBy cmp) as + (seqSpine . sortLBy cmp) as +-- @@@ INLINE sortByPar also, doc +-} diff --git a/src/Math/Algebra/Commutative/BinPoly.hs b/src/Math/Algebra/Commutative/BinPoly.hs index 664a5de..133a522 100644 --- a/src/Math/Algebra/Commutative/BinPoly.hs +++ b/src/Math/Algebra/Commutative/BinPoly.hs @@ -31,7 +31,6 @@ import Control.Monad.Extra (pureIf) import Data.Bits ((.&.), (.|.), bit, complement, countLeadingZeros, popCount, testBit, unsafeShiftL, unsafeShiftR) import Data.Foldable (toList) -import Data.List (sortBy) import Data.Maybe (catMaybes, fromJust) import Data.Word (Word64) import StrictList2 (pattern (:!)) @@ -115,7 +114,7 @@ data BPOtherOps ev vals = BPOtherOps { } bpSortCancel :: Cmp ev -> SL.List ev -> BinPoly ev -bpSortCancel evCmp evs = cancelRev (sortBy evCmp (SL.toListReversed evs)) SL.Nil +bpSortCancel evCmp evs = cancelRev (sortLBy evCmp (SL.toListReversed evs)) SL.Nil where cancelRev (v : t1@(w : ~t2)) r | evCmp v w == EQ = cancelRev t2 r diff --git a/src/Math/Algebra/Commutative/Field/ZModPW.hs b/src/Math/Algebra/Commutative/Field/ZModPW.hs index cc71979..71aa2f7 100644 --- a/src/Math/Algebra/Commutative/Field/ZModPW.hs +++ b/src/Math/Algebra/Commutative/Field/ZModPW.hs @@ -28,10 +28,6 @@ zzModPW = (field numAG (*) 1 fromInteger recip, balRep) in toInteger (fromIntegral (if u > maxBalRep then u - p else u) :: Int) -toWord :: Mod m -> Word --- ^ The result is @\< @m@. @@ move to Data.Mod.Word -toWord = unsafeCoerce - unsafeFromWord :: Word -> Mod m -- ^ The argument must be @\< m@. @@ move to Data.Mod.Word unsafeFromWord = unsafeCoerce @@ -42,7 +38,7 @@ newtype ModWord32 (m :: Nat) = ModW32 { w32 :: Word32 {- ^ @w32 \< m@ -} } deriving newtype (Eq, Show, Prim) modWToW32 :: Mod m -> ModWord32 m -modWToW32 = ModW32 . fromIntegral . toWord +modWToW32 = ModW32 . fromIntegral . unMod modW32ToW :: ModWord32 m -> Mod m modW32ToW = unsafeFromWord . fromIntegral . (.w32) @@ -69,7 +65,7 @@ newtype ModWord16 (m :: Nat) = ModW16 { w16 :: Word16 {- ^ @w16 \< m@ -} } deriving newtype (Eq, Show, Prim) modWToW16 :: Mod m -> ModWord16 m -modWToW16 = ModW16 . fromIntegral . toWord +modWToW16 = ModW16 . fromIntegral . unMod modW16ToW :: ModWord16 m -> Mod m modW16ToW = unsafeFromWord . fromIntegral . (.w16) @@ -96,7 +92,7 @@ newtype ModWord8 (m :: Nat) = ModW8 { w8 :: Word8 {- ^ @w8 \< m@ -} } deriving newtype (Eq, Show, Prim) modWToW8 :: Mod m -> ModWord8 m -modWToW8 = ModW8 . fromIntegral . toWord +modWToW8 = ModW8 . fromIntegral . unMod modW8ToW :: ModWord8 m -> Mod m modW8ToW = unsafeFromWord . fromIntegral . (.w8) diff --git a/src/Math/Algebra/Commutative/GroebnerBasis.hs b/src/Math/Algebra/Commutative/GroebnerBasis.hs index 31dd7c4..e6f204d 100644 --- a/src/Math/Algebra/Commutative/GroebnerBasis.hs +++ b/src/Math/Algebra/Commutative/GroebnerBasis.hs @@ -22,7 +22,7 @@ import Control.Monad.Extra (ifM, orM, whenM) import Data.Bits ((.&.), (.|.)) import Data.Foldable (find, minimumBy, toList) import Data.Int (Int64) -import Data.List (elemIndex, findIndices, groupBy, sortBy) +import Data.List (elemIndex, findIndices, groupBy) import Data.List.Extra (chunksOf) import Data.Maybe (catMaybes, fromJust, isJust, isNothing, listToMaybe, mapMaybe) import qualified Data.RRBVector as GBV @@ -250,7 +250,7 @@ updatePairs (GBPolyOps { nVars, evCmp, extraSPairs, useSugar }) gMGis ijcs tGi itcs = catMaybes (zipWith (\i -> fmap (giToSp i t)) [0..] itMGis) :: [SPair ev] -- "sloppy" sugar method: itcss = TS.measurePure "1.1.2 sort/group new itcs" $ seqElts $ - groupBy (cmpEq lcmCmp) (sortBy lcmCmp itcs) + groupBy (cmpEq lcmCmp) (sortLBy lcmCmp itcs) itcsToC = (.m) . head itcss' = TS.measurePure "1.1.3 M_ik" $ seqElts $ filter (noDivs . itcsToC) itcss where -- criterion M_ik; 3 seqElts calls for TS.measurePure @@ -713,7 +713,7 @@ groebnerBasis gbpA@(GBPolyOps { .. }) gbTrace gbi0 newGens = TS.scope $ do pure True doSP = maybe (pure False) newG =<< setPop ijcsRef mapM_ (\g -> addGenHN =<< reduce_n (EPolyHDeg g (homogDeg0 g))) - (sortBy (evCmp `on` leadEvNz) (filter (not . pIsZero) newGens)) + (sortLBy (evCmp `on` leadEvNz) (filter (not . pIsZero) newGens)) numSleepingVar <- newTVarIO (0 :: Int) let traceTime = do cpuTime2 <- getCPUTime diff --git a/src/Math/Algebra/General/Algebra.hs b/src/Math/Algebra/General/Algebra.hs index db77e96..92b3159 100644 --- a/src/Math/Algebra/General/Algebra.hs +++ b/src/Math/Algebra/General/Algebra.hs @@ -59,6 +59,7 @@ module Math.Algebra.General.Algebra ( Cmp, cmpEq, maxBy, minBy, + isSortedBy, -- * Monoids and Groups -- $monoids @@ -96,7 +97,7 @@ module Math.Algebra.General.Algebra ( -- * Basic numeric rings numAG, numRing, -- ** Integer - zzAG, zzDiv, zzRing, + zzAG, zzDiv, zzRing, intRing, -- ** Double dblAG, dblRing, @@ -123,7 +124,7 @@ module Math.Algebra.General.Algebra ( #if ! MIN_VERSION_base(4, 18, 0) liftA2, #endif - assert + assert, sortLBy ) where import GHC.Records @@ -132,6 +133,7 @@ import GHC.Records import Control.Applicative (liftA2) -- unnecesary in base 4.18+, since in Prelude #endif import Control.Exception (assert) +import Control.Parallel.Cooperative (sortLBy) import Data.Bifunctor (bimap, second) import Data.Char (isDigit) #if ! MIN_VERSION_base(4, 20, 0) @@ -243,6 +245,16 @@ minBy :: Cmp a -> Op2 a -- ^ > minBy cmp x y = if cmp x y /= GT then x else y minBy cmp x y = if cmp x y /= GT then x else y +-- | The 'isSortedBy' function returns 'True' iff the predicate returns true +-- for all adjacent pairs of elements in the list. +isSortedBy :: (a -> a -> Bool) -> [a] -> Bool +-- from Data.List.Ordered in data-ordlist +isSortedBy lte = loop + where + loop [] = True + loop [_] = True + loop (x:y:zs) = (x `lte` y) && loop (y:zs) + -- * Monoids and Groups @@ -642,6 +654,15 @@ zzRing :: Ring Integer -- ^ the ring of integers ℤ zzRing = numRing integralDomainFlags zzDiv +intRing :: Ring Int +{- ^ Arithmetic mod @2^n@, where an @Int@ has @n@ bits. Division is simply @Int@ division. This + is mostly used for testing and benchmarks. -} +intRing = numRing rFlags (const intDiv) + where + rFlags = RingFlags { commutative = True, noZeroDivisors = False, nzInverses = False } + intDiv y 0 = (0, y) + intDiv y m = quotRem y m + -- ** Double dblAG :: AbelianGroup Double diff --git a/src/Math/Algebra/General/SparseSum.hs b/src/Math/Algebra/General/SparseSum.hs index b2b39f7..c8c29a3 100644 --- a/src/Math/Algebra/General/SparseSum.hs +++ b/src/Math/Algebra/General/SparseSum.hs @@ -22,7 +22,6 @@ import Math.Algebra.Category.Category import Data.Bifunctor (Bifunctor(first, second, bimap)) import Data.Foldable (toList) -import Data.List (sortBy) import Data.Maybe (fromJust) import GHC.Stack (HasCallStack) import StrictList2 (pattern (:!)) @@ -157,7 +156,7 @@ ssAGUniv (AbelianGroup cFlags eq plus _zero isZero neg) dCmp = ssFoldSort :: AbelianGroup c -> Cmp d -> [SSTerm c d] -> SparseSum c d -- ^ sorts and combines the terms; input terms may have coefs which are 0 ssFoldSort (AbelianGroup _ _ cPlus _ cIsZero _) dCmp cds0 = - go ssZero (sortBy (dCmp `on` (.d)) cds0) + go ssZero (sortLBy (dCmp `on` (.d)) cds0) where go done [] = done go done (cd : t) = go1 done cd.d cd.c t diff --git a/src/Math/Algebra/Linear/SparseColumnsMatrix.hs b/src/Math/Algebra/Linear/SparseColumnsMatrix.hs index bce6c11..8aec2d5 100644 --- a/src/Math/Algebra/Linear/SparseColumnsMatrix.hs +++ b/src/Math/Algebra/Linear/SparseColumnsMatrix.hs @@ -1,18 +1,19 @@ {-# LANGUAGE Strict #-} -{- | A sparse columns 'Matrix' is implemented as a 'SV.Vector' of 'SV.Vector's. +{- | A sparse columns 'Matrix' is implemented as a 'SV.Vector' of 'SV.Vector's or + 'SV.VectorU's. This data structure is also fairly efficient for dense matrices. This module uses LANGUAGE Strict. In particular, constructor fields and function arguments - are strict unless marked with a ~. Also, a 'Matrix' is strict in both its spine and its - elements. Finally, "Data.Strict.Maybe" and "Data.Strict.Tuple" may be often used. + are strict unless marked with a ~. Also, a 'Matrix' or 'MatrixA' is strict in both its spine + and its elements. Finally, "Data.Strict.Maybe" and "Data.Strict.Tuple" may be often used. -} module Math.Algebra.Linear.SparseColumnsMatrix ( -- * Matrix - Matrix, - -- * Inlinable function(s) + MatrixA, Matrix, MatrixU, + -- * Inlinable or worker-wrapper function(s) getCol, -- * Other functions PLUQMats(..), MatrixOps(..), matrixOps, transpose @@ -21,9 +22,9 @@ module Math.Algebra.Linear.SparseColumnsMatrix ( import Math.Algebra.General.Algebra import Math.Algebra.Linear.SparseVector as SV -import Control.Monad.Extra (pureIf) -import Data.Maybe (fromMaybe) -import Data.Strict.Classes (toLazy) +import qualified Data.Primitive.Contiguous as C +import Data.Primitive.PrimArray (PrimArray) +import Data.Primitive.SmallArray (SmallArray) import qualified Data.Strict.Maybe as S import qualified Data.Strict.Tuple as S import Data.Strict.Tuple ((:!:), pattern (:!:)) @@ -31,82 +32,85 @@ import Data.Strict.Tuple ((:!:), pattern (:!:)) -- * Matrix -type Matrix c = SV.Vector (SV.Vector c) -{- ^ a matrix stored as columns, implementing a linear map between finitely generated right +type MatrixA arr c = SV.Vector (SV.VectorA arr c) +{- ^ A matrix stored as columns, implementing a linear map between finitely generated right R-modules with basis elements indexed starting at 0. The columns are the images of the - source module's basis elements. -} + source module's basis elements. @(arr c)@ is a small array of nonzero @c@s. -} --- * Inlinable function(s) +type Matrix c = MatrixA SmallArray c +-- ^ A matrix of boxed elements. -getCol :: Matrix c -> Int -> SV.Vector c +type MatrixU c = VectorA PrimArray c +-- ^ A matrix of unboxed elements. A common example is integers modulo a single-word prime. + +-- * Inlinable or worker-wrapper function(s) + +getCol :: C.Contiguous arr => MatrixA arr c -> Int -> SV.VectorA arr c getCol = SV.index SV.zero -- * Other functions -data PLUQMats c = PLUQMats { pi, p, l, u, q, qi :: Matrix c, r :: Int } --- ^ A matrix factorization @p * l * u * q@ where: --- * @p@ and @q@ are permutation matrices --- * @pi@ and @qi@ are the inverses of @p@ and @q@ --- * @l@ and @u@ are /lower/ and /upper trapezoidal/ respectively --- (@i \< j => l[i, j] = u[j, i] = 0@) --- * @0 \<= i \< r => l[i, i]@ and @u[i, i]@ are units --- * @r \<= i => l[j, i] = u[i, j] = 0@ --- --- Descriptively, @l@ has @r@ columns and @u@ has @r@ rows, their main diagonal elements are --- units, and they are 0 before that. +data PLUQMats arr c = + PLUQMats { p :: Permute, l, u :: MatrixA arr c, q :: Permute, r :: Int } +{- ^ A matrix factorization @p * l * u * q@ where: + + * @p@ and @q@ are permutations + * @l@ and @u@ are /lower/ and /upper trapezoidal/ respectively + (@i \< j => l[i, j] = u[j, i] = 0@) + * @0 \<= i \< r => l[i, i]@ and @u[i, i]@ are units + * @r \<= i => l[j, i] = u[i, j] = 0@ for all @j@ + + Descriptively, @l@ has @r@ columns and @u@ has @r@ rows, their main diagonal elements are + units, and they are 0 before that. -} -data MatrixOps c = MatrixOps { - vModR :: ModR c (SV.Vector c), - matTimesV :: Matrix c -> Op1 (SV.Vector c), - matRing :: Ring (Matrix c), - diagNMat :: Int -> c -> Matrix c, -- ^ create a diagonal nxn matrix - upperTriSolve :: Matrix c -> Op1 (SV.Vector c), - -- ^ @upperTriSolve m v@ returns the unique @w@ such that @m * w = v@. @m@ must be upper - -- triangular, and its main diagonal elements must be units starting when @v@ becomes - -- nonzero. That is, if @v[i]@ is nonzero and @i \<= k@, then @m[k, k]@ must be a unit. - lowerTrDivBy :: Matrix c -> SV.Vector c -> (SV.Vector c, SV.Vector c), - -- ^ @lowerTrDivBy m v = (q, r) => v = m*q + r@ and @rows[0 .. t](r) = 0@; where @m@ - -- is lower trapezoidal, has max column @t@, and its main diagonal elements are units - -- starting when @v@ becomes nonzero. That is, if @v[i]@ is nonzero and @i \<= k@, then - -- @m[k, k]@ must be a unit. - lowerTriSolve :: Matrix c -> Op1 (SV.Vector c) - -- ^ @lowerTriSolve m v@ returns the unique @w@ such that @m * w = v@. @m@ must be lower - -- triangular, and its main diagonal elements must be units starting when @v@ becomes - -- nonzero. That is, if @v[i]@ is nonzero and @i \<= k@, then @m[k, k]@ must be a unit. +{- | @upperTriSolve m v@ returns the unique @w@ such that @m * w = v@. @m@ must be upper + triangular (hence square), and its main diagonal elements must be units starting when @v@ + becomes nonzero. That is, if @v[i]@ is nonzero and @i \<= k@, then @m[k, k]@ must be a + unit. (We also require @w@ to be zero before @v@ first becomes nonzero.) + + @lowerTrDivBy m v = (q, r) => v = m*q + r@ and @rows[0 .. t](r) = 0@; where @m@ is lower + trapezoidal, has max column @t@, and its main diagonal elements are units starting when @v@ + becomes nonzero. That is, if @v[i]@ is nonzero and @i \<= k \<= t@, then @m[k, k]@ must be a + unit. + + @lowerTriSolve m v@ returns the unique @w@ such that @m * w = v@. @m@ must be lower + triangular, and its main diagonal elements must be units starting when @v@ becomes nonzero. + That is, if @v[i]@ is nonzero and @i \<= k@, then @m[k, k]@ must be a unit. (We also require + @w@ to be zero before @v@ first becomes nonzero.) -} +data MatrixOps arr c = MatrixOps { + vModR :: ModR c (SV.VectorA arr c), + matTimesV :: MatrixA arr c -> Op1 (SV.VectorA arr c), + matRing :: Ring (MatrixA arr c), + diagNMat :: Int -> (Int -> c) -> MatrixA arr c, -- ^ create a diagonal nxn matrix + upperTriSolve :: MatrixA arr c -> Op1 (SV.VectorA arr c), + lowerTrDivBy :: MatrixA arr c -> SV.VectorA arr c -> + (SV.VectorA arr c, SV.VectorA arr c), + lowerTriSolve :: MatrixA arr c -> Op1 (SV.VectorA arr c) } -matrixOps :: forall c. Ring c -> Int -> MatrixOps c -{- ^ ring of matrices. @one@ and @fromZ@ of @matrixOps cR n@ will create @n x n@ matrices. +matrixOps :: forall c arr. (C.Contiguous arr, C.Element arr c) => + Ring c -> ModR c (VectorA arr c) -> Int -> MatrixOps arr c +{- ^ Ring of matrices. @one@ and @fromZ@ of @matrixOps cR vModR n@ will create @n x n@ matrices. @matRing.ag.monFlags.nontrivial@ assumes @n > 0@. -} -matrixOps cR maxN = MatrixOps { .. } +matrixOps cR vModR maxN = MatrixOps { .. } where cIsZero = cR.isZero - vAG = SV.mkAG cR.ag + vAG = vModR.ag cFlags = cR.rFlags - cNzds = cFlags.noZeroDivisors - timesNzC = (if cNzds then SV.timesNzdC else SV.timesC) cR - timesNzdsC c v = if cIsZero c then SV.zero else SV.timesNzdC cR c v + timesNzC = vModR.scale -- use timesNzdCU when possible for speed minusTimesNzC v q w = vAG.plus v $! timesNzC (cR.neg q) w - vBDiv doFull v w = fromMaybe (cR.zero, v) $ do - i :!: wc <- toLazy $ SV.headPairMaybe w - vc <- if doFull.b then toLazy $ SV.indexMaybe v i else - do vi :!: c <- toLazy $ SV.headPairMaybe v; pureIf (vi == i) c - let (q, _) = cR.bDiv doFull vc wc - pureIf (not (cIsZero q)) (q, minusTimesNzC v q w) - vModR = Module vAG (if cNzds then timesNzdsC else SV.timesC cR) vBDiv matAG = SV.mkAG vAG matTimesV = flip (SV.dotWith vAG timesNzC) isTrivial = maxN == 0 || cIsZero cR.one matFlags = if maxN == 1 then cFlags else RingFlags { commutative = isTrivial, noZeroDivisors = isTrivial, nzInverses = isTrivial } a *~ b = SV.mapC SV.isZero (matTimesV a) b - one = diagNMat maxN cR.one - fromZ z = diagNMat maxN (cR.fromZ z) + one = diagNMat maxN (const cR.one) + fromZ z = diagNMat maxN (const (cR.fromZ z)) matRing = Ring matAG matFlags (*~) one fromZ bDiv - diagNMat n c = if cIsZero c then SV.zero else - SV.fromDistinctAscNzs [i :!: SV.fromNzIC i c | i <- [0 .. n - 1]] + diagNMat n f = SV.fromDistinctAscNzs + [i :!: SV.fromNzIC i c | i <- [0 .. n - 1], let c = f i, not (cIsZero c)] upperTriSolve m = loop [] where - -- @@@: loop done v = case SV.lastPairMaybe v of S.Nothing -> SV.fromDistinctAscNzs done S.Just (i :!: c) -> loop ((i :!: q) : done) v2 @@ -129,10 +133,14 @@ matrixOps cR maxN = MatrixOps { .. } in loop ((i :!: q) : done) v2 where ~end = (SV.fromDistinctAscNzs (reverse done), v) - lowerTriSolve m v = assert (SV.isZero r) q + lowerTriSolve m v = if SV.isZero r then q else error "lowerTriSolve: illegal inputs" where (q, r) = lowerTrDivBy m v bDiv _doFull y _t = (SV.zero, y) -- @@ improve (incl. solving linear equations in parallel) +{-# SPECIALIZE matrixOps :: Ring c -> ModR c (Vector c) -> Int -> MatrixOps SmallArray c #-} -transpose :: Op1 (Matrix c) -transpose = SV.foldBIMap' SV.join SV.zero (SV.mapNzFC . SV.fromNzIC) +transpose :: (C.Contiguous arr, C.Element arr c) => Op1 (MatrixA arr c) +-- ^ Transpose an mxn matrix into an nxm one, swapping elements (i, j) and (j, i). +transpose = SV.foldBIMap' (SV.unionWith (const False) SV.join) SV.zero + (SV.mapNzFC . SV.fromNzIC) +{-# SPECIALIZE transpose :: Op1 (Matrix c) #-} diff --git a/src/Math/Algebra/Linear/SparseVector.hs b/src/Math/Algebra/Linear/SparseVector.hs index 4eae961..7c48b73 100644 --- a/src/Math/Algebra/Linear/SparseVector.hs +++ b/src/Math/Algebra/Linear/SparseVector.hs @@ -28,43 +28,49 @@ module Math.Algebra.Linear.SparseVector ( -- * Vector VectorA, Vector, VectorU, check, -- * Create - zero, fromPIC, fromIMaybeC, fromNzIC, fromDistinctAscNzs, + zero, fromPIC, fromIMaybeC, fromNzIC, fromDistinctAscNzs, fromDistinctAscNzsNoCheck, + fromDistinctNzs, fromNzs, -- * Query isZero, index, indexMaybe, size, headPairMaybe, headPair, lastPairMaybe, lastPair, -- * Fold - foldBIMap', iFoldR, iFoldL, toDistinctAscNzs, + foldBIMap', iFoldR, iFoldL, keys, toDistinctAscNzs, -- * Map mapC, mapNzFC, mapCMaybeWithIndex, - -- * Modify/Permute - vApply, invPermute, swap, + -- * Zip/Combine + unionWith, plusU, unionDisj, vApply, andNot, foldLIntersect', -- * Split/Join split, join, -- * Addition - unionWith, plusU, mkAG, + mkAG, mkAGU, -- * Multiplication - dotWith, timesNzdC, timesNzdCU, timesC, monicizeUnit, + dotWith, timesNzdC, timesNzdCU, timesC, monicizeUnit, mkModR, mkModRU, + -- * Permute + Permute(to, from), pToF, pIdent, pSwap, pCycle, fToP, injToPermute, pCompose, pGroup, + permuteV, swap, sortPermute, -- * I/O showPrec ) where import Math.Algebra.General.Algebra +import Control.Monad.Extra (pureIf) import Control.Monad.ST (ST, runST) import Control.Parallel.Cooperative (fbTruncLog2, lowNzBit) -import Data.Bits ((.&.), (.|.), (.^.), (!<<.), (!>>.), countTrailingZeros, finiteBitSize, - popCount) -- @@ is countLeadingZeros faster on ARM? +import Data.Bits ((.&.), (.|.), (.^.), (!<<.), (!>>.), complement, countTrailingZeros, + finiteBitSize, popCount) -- @@ is countLeadingZeros faster on ARM? import Data.Functor.Classes (liftEq) -import Data.Mod.Word (Mod) +import Data.Maybe (fromMaybe) import qualified Data.Primitive.Contiguous as C import Data.Primitive.PrimArray (PrimArray) import Data.Primitive.SmallArray (SmallArray) import Data.Primitive.Types (Prim) +import Data.Strict.Classes (toLazy) import qualified Data.Strict.Maybe as S import qualified Data.Strict.Tuple as S import Data.Strict.Tuple ((:!:), pattern (:!:)) import Data.Word (Word64) import Fmt ((+|), (|+)) -import GHC.TypeNats (KnownNat) +import GHC.Stack (HasCallStack) nothingIf :: Pred a -> a -> S.Maybe a -- move and export? @@ -74,8 +80,13 @@ nothingIf p a = if p a then S.Nothing else S.Just a b64 :: Int -> Word64 -- unsafe 'bit'; 0 <= i < 64 b64 i = 1 !<<. i -b64Index :: Word64 -> Word64 -> Int -- b is a bit; position in nonzero bits or -1 -b64Index b bs = if b .&. bs == 0 then -1 else popCount ((b - 1) .&. bs) +b64Index :: Word64 -> Word64 -> Int -- b is a bit; position in nonzero bits +b64Index b bs = popCount ((b - 1) .&. bs) +{-# INLINE b64Index #-} + +b64IndexMaybe :: Word64 -> Word64 -> S.Maybe Int -- like elemIndex for bits +b64IndexMaybe b bs = if b .&. bs /= 0 then S.Just (b64Index b bs) else S.Nothing +{-# INLINE b64IndexMaybe #-} -- @@ use C.unsafeShrinkAndFreeze after its next release, when it shrinks in place: @@ -178,12 +189,22 @@ fromNzIC i c = if i < 0 then error "fromNzIC: negative index" else (svv (b64 (j .&. 63)) w2 (C.singleton v)) {-# SPECIALIZE fromNzIC :: Int -> c -> Vector c #-} -fromDistinctAscNzs :: forall arr c. (C.Contiguous arr, C.Element arr c) => +fromDistinctAscNzs :: forall arr c. (C.Contiguous arr, C.Element arr c, HasCallStack) => [Int :!: c] -> VectorA arr c {- ^ The 'Int's must be distinct and ascending, and the @c@s must be nonzero. Usually \(n\) steps, though up to \((d - log_{64} n) n\) if the vector is very sparse. -} -fromDistinctAscNzs [] = zero -fromDistinctAscNzs ((i0 :!: c0) : t0) = runST $ do +fromDistinctAscNzs ps = + if isSortedBy (\p q -> S.fst p < S.fst q) ps then fromDistinctAscNzsNoCheck ps + else error "fromDistinctAscNzs: indices not distinct and ascending" +{-# SPECIALIZE fromDistinctAscNzs :: [Int :!: c] -> Vector c #-} + +fromDistinctAscNzsNoCheck :: forall arr c. (C.Contiguous arr, C.Element arr c) => + [Int :!: c] -> VectorA arr c +{- ^ Like 'fromDistinctAscNzs' but the input list is not checked. If its indices are not + distinct and ascending, undefined behavior (due to array bounds errors) may result. -} +fromDistinctAscNzsNoCheck [] = zero +fromDistinctAscNzsNoCheck ((i0 :!: c0) : t0) = + if i0 < 0 then error "fromDistinctAscNzs: negative index" else runST $ do sveBuf <- C.new @arr 64 let mkSve bs j i c ics = do C.write sveBuf j c @@ -212,7 +233,19 @@ fromDistinctAscNzs ((i0 :!: c0) : t0) = runST $ do mkSvv 0 iW2 nzts 0 i c ics trunc6 n = n - n `rem` 6 -- same as n `quot` 6 * 6 S.fst <$> mkSV (trunc6 (fbTruncLog2 (maxBound :: Int))) i0 c0 t0 -{-# SPECIALIZE fromDistinctAscNzs :: [Int :!: c] -> Vector c #-} +{-# SPECIALIZE fromDistinctAscNzsNoCheck :: [Int :!: c] -> Vector c #-} + +fromDistinctNzs :: forall arr c. (C.Contiguous arr, C.Element arr c) => + [Int :!: c] -> VectorA arr c +{- ^ The 'Int's must be distinct, and the @c@s must be nonzero. \(O(n log_2 n)\). -} +fromDistinctNzs = fromDistinctAscNzs . sortLBy (compare `on` S.fst) + -- @@ parallelize? +{-# SPECIALIZE fromDistinctNzs :: [Int :!: c] -> Vector c #-} + +fromNzs :: forall arr c. (C.Contiguous arr, C.Element arr c) => [c] -> VectorA arr c +{- ^ The @c@s must be nonzero. \(n\) steps. -} +fromNzs = fromDistinctAscNzsNoCheck . S.zip [0 ..] +{-# SPECIALIZE fromNzs :: [c] -> Vector c #-} -- * Query @@ -232,14 +265,11 @@ indexMaybe vRoot iRoot | iRoot < 0 = error ("SV.indexMaybe: negative index " <> show iRoot) | otherwise = go vRoot iRoot where - go (SVE bs nzs) i = if i > 63 || j == -1 then S.Nothing else S.Just (C.index nzs j) - where - ~j = b64Index (b64 i) bs - go (SVV bs iW2 nzts) i = if i0 > 63 || j == -1 then S.Nothing else - go (C.index nzts j) (i .&. ((1 !<<. iW2) - 1)) + i0ToMJ i0 bs = if i0 > 63 then S.Nothing else b64IndexMaybe (b64 i0) bs + go (SVE bs nzs) i = C.index nzs <$> i0ToMJ i bs + go (SVV bs iW2 nzts) i = S.maybe S.Nothing jF (i0ToMJ (i !>>. iW2) bs) where - i0 = i !>>. iW2 - ~j = b64Index (b64 i0) bs + jF j = go (C.index nzts j) (i .&. ((1 !<<. iW2) - 1)) {-# SPECIALIZE indexMaybe :: Vector c -> Int -> S.Maybe c #-} size :: (C.Contiguous arr, C.Element arr c) => VectorA arr c -> Int @@ -343,6 +373,11 @@ iFoldL f = go 0 go start ~z (SVV bs iW2 nzts) = aIFoldL go z start bs iW2 nzts {-# SPECIALIZE iFoldL :: (Int -> t -> c -> t) -> t -> Vector c -> t #-} +keys :: (C.Contiguous arr, C.Element arr c) => VectorA arr c -> [Int] +-- ^ Non-missing indices of the vector, in increasing order. \(m\) steps. +keys = iFoldR (\i _c -> (i :)) [] +{-# SPECIALIZE keys :: Vector c -> [Int] #-} + toDistinctAscNzs :: (C.Contiguous arr, C.Element arr c) => VectorA arr c -> [Int :!: c] -- ^ @toDistinctAscNzs = iFoldR (\i c -> ((i :!: c) :)) []@. \(m\) steps. toDistinctAscNzs = iFoldR (\i c -> ((i :!: c) :)) [] @@ -361,7 +396,7 @@ mkSv _ svvF (SVV bs iW2 nzts) = svv bs' iW2 nzts' combineSv :: C.Contiguous arr2 => (Word64 -> arr0 c0 -> Word64 -> arr1 c1 -> Word64 :!: arr2 c2) -> - (Word64 -> SmallArray (VectorA arr0 c0) -> Word64 -> SmallArray (VectorA arr1 c1) -> + (Int -> Word64 -> SmallArray (VectorA arr0 c0) -> Word64 -> SmallArray (VectorA arr1 c1) -> Word64 :!: SmallArray (VectorA arr2 c2)) -> VectorA arr0 c0 -> VectorA arr1 c1 -> VectorA arr2 c2 {- For speed, the caller may want to check isZero, or getIW2 x < or > getIW2 y. Else note svvF @@ -370,11 +405,11 @@ combineSv sveF svvF = go where go (SVE bs0 nzs0) (SVE bs1 nzs1) = S.uncurry SVE $ sveF bs0 nzs0 bs1 nzs1 go x y - | getIW2 x > getIW2 y = go x (SVV 1 (getIW2 y + 6) (C.singleton y)) - | getIW2 x < getIW2 y = go (SVV 1 (getIW2 x + 6) (C.singleton x)) y + | getIW2 x > getIW2 y = go x (SVV 1 (getIW2 x) (C.singleton y)) + | getIW2 x < getIW2 y = go (SVV 1 (getIW2 y) (C.singleton x)) y go (SVV bs0 iW2 nzts0) (SVV bs1 _ nzts1) = svv bs2 iW2 nzts2 where - bs2 :!: nzts2 = svvF bs0 nzts0 bs1 nzts1 + bs2 :!: nzts2 = svvF iW2 bs0 nzts0 bs1 nzts1 go _ _ = undefined aMapC :: (C.Contiguous arr, C.Element arr c, C.Contiguous arr', C.Element arr' c') => @@ -382,7 +417,7 @@ aMapC :: (C.Contiguous arr, C.Element arr c, C.Contiguous arr', C.Elem -- assumes @popCount bs == C.size nzs@ aMapC is0 f bs nzs = assert (popCount bs == C.size nzs) $ runST $ do nzs' <- C.new (C.size nzs) - let go 0 bs' _ j' = (bs' :!: ) <$> unsafeShrinkAndFreeze nzs' j' + let go 0 bs' _ j' = (bs' :!:) <$> unsafeShrinkAndFreeze nzs' j' go bsTodo bs' j j' = if is0 c' then go (bsTodo .^. b) (bs' .^. b) (j + 1) j' else do C.write nzs' j' c' go (bsTodo .^. b) bs' (j + 1) (j' + 1) @@ -415,7 +450,7 @@ aMapCMaybeWithIndex :: (C.Contiguous arr, C.Element arr c, C.Contiguous arr' Word64 -> Int -> arr c -> Word64 :!: arr' c' aMapCMaybeWithIndex f start bs iW2 nzs = assert (popCount bs == C.size nzs) $ runST $ do nzs' <- C.new (C.size nzs) - let go 0 bs' _ j' = (bs' :!: ) <$> unsafeShrinkAndFreeze nzs' j' + let go 0 bs' _ j' = (bs' :!:) <$> unsafeShrinkAndFreeze nzs' j' go bsTodo bs' j j' = case mc' of S.Nothing -> go (bsTodo .^. b) (bs' .^. b) (j + 1) j' S.Just c' -> do @@ -426,6 +461,7 @@ aMapCMaybeWithIndex f start bs iW2 nzs = assert (popCount bs == C.size nzs) $ r mc' = f !$ start + i0 !<<. iW2 !$ C.index nzs j b = b64 i0 go bs bs 0 0 +{-# INLINABLE aMapCMaybeWithIndex #-} {-# SPECIALIZE aMapCMaybeWithIndex :: (Int -> c -> S.Maybe c') -> Int -> Word64 -> Int -> SmallArray c -> Word64 :!: SmallArray c' #-} @@ -441,108 +477,10 @@ mapCMaybeWithIndex f = fromMaybeSv . go 0 go start (SVV bs iW2 nzts) = toMaybeSv $ svv bs' iW2 nzts' where bs' :!: nzts' = aMapCMaybeWithIndex go start bs iW2 nzts +{-# INLINABLE mapCMaybeWithIndex #-} {-# SPECIALIZE mapCMaybeWithIndex :: (Int -> c -> S.Maybe c') -> Vector c -> Vector c' #-} --- * Modify/Permute - -aVApply :: (C.Contiguous dArr, C.Element dArr d, C.Contiguous cArr, C.Element cArr c) => - (d -> Op1 (S.Maybe c)) -> Word64 -> dArr d -> Word64 -> cArr c -> - Word64 :!: cArr c -aVApply f bs0 nzs0 bs1 nzs1 = runST $ do - let bsAll = bs0 .|. bs1 - nzs2 <- C.new (popCount bsAll) - let go 0 bs2 _ _ i2 = (bs2 :!: ) <$> unsafeShrinkAndFreeze nzs2 i2 - go bsTodo bs2 i0 i1 i2 - | bs0 .&. b == 0 = do - C.write nzs2 i2 $! C.index nzs1 i1 - go bsTodo' bs2 i0 (i1 + 1) (i2 + 1) - | bs1 .&. b == 0 = goD S.Nothing i1 - | otherwise = goD (S.Just (C.index nzs1 i1)) (i1 + 1) - where - b = lowNzBit bsTodo - bsTodo' = bsTodo .^. b - goD mc1 i1' = do - let mc = f !$ C.index nzs0 i0 !$ mc1 - case mc of - S.Nothing -> go bsTodo' (bs2 .^. b) (i0 + 1) i1' i2 - S.Just c -> do - C.write nzs2 i2 c - go bsTodo' bs2 (i0 + 1) i1' (i2 + 1) - go bsAll bsAll 0 0 0 -{-# SPECIALIZE aVApply :: (d -> Op1 (S.Maybe c)) -> Word64 -> SmallArray d -> - Word64 -> SmallArray c -> Word64 :!: SmallArray c #-} - -vApply :: (C.Contiguous dArr, C.Element dArr d, C.Contiguous cArr, C.Element cArr c) => - ({- @@@@ Int -> -} d -> Op1 (S.Maybe c)) -> VectorA dArr d -> Op1 (VectorA cArr c) -{- ^ @vApply f ds v@ applies @f {- @@@ i -} ds[i]@ to @v[i]@, for each non-missing element in @ds@. - Usually \(m + n\) steps, though \(m\) if the input trees have few common nodes, or up to - \(64 (d - log_{64} m) m\) if the first input is very sparse and the second input is very - dense. -} -vApply f ds0 v0 = fromMaybeSv $ go 0 ds0 (toMaybeSv v0) - where - go start ds mv - | S.Nothing <- mv = toMaybeSv $ - mapCMaybeWithIndex (\i d -> f {- @@@ (start + i) -} d S.Nothing) ds - | isZero ds = mv - | S.Just v <- mv = toMaybeSv $ combineSv (aVApply f) (aVApply (go start)) ds v -{-# SPECIALIZE vApply :: (d -> Op1 (S.Maybe c)) -> Vector d -> Op1 (Vector c) #-} - -invPermute :: (C.Contiguous arr, C.Element arr c) => VectorU Int -> Op1 (VectorA arr c) -{- ^ @invPermute js v@ applies the inverse of the sparse permutation @js@, moving each - @v[js[i]]@ to index @i@ in the result. If @js[i]@ is missing in @js@, then @v[i]@ is used. - Thus the result is the same as @vApply (\j _ -> indexMaybe v j) js v@. \(d m\) steps if @js@ - is dense, or up to \(d m + 64 (d - log_{64} m) m\) steps if @js@ is very sparse and @v@ is - very dense. -} -invPermute js v = vApply (\j _ -> indexMaybe v j) js v -{-# SPECIALIZE invPermute :: VectorU Int -> Op1 (Vector c) #-} -{-# SPECIALIZE invPermute :: VectorU Int -> Op1 (VectorU Int) #-} - -swap :: (C.Contiguous arr, C.Element arr c) => Int -> Int -> Op1 (VectorA arr c) --- ^ swap two coordinates. Up to \(128 d\) steps for a very dense vector. -swap i j - | i < j = invPermute (fromDistinctAscNzs [i :!: j, j :!: i]) - | i > j = swap j i - | otherwise = id -{-# SPECIALIZE swap :: Int -> Int -> Op1 (Vector c) #-} - --- * Split/Join - -split :: (C.Contiguous arr, C.Element arr c) => Int -> VectorA arr c -> - VectorA arr c :!: VectorA arr c -{- ^ @split i v@ splits @v@ into parts with indices @\< i@, and @>= i@. Up to \(64 d\) steps - for a dense @v@, or fewer if @i@ is a multiple of a high power of 2. -} -split i = go (max i 0) - where - go k w - | w.bs == lowBs = w :!: zero - | k == 0 = zero :!: w - | SVE bs nzs <- w = SVE lowBs (C.clone (C.slice nzs 0 j)) - :!: SVE (bs .^. lowBs) (C.clone (C.slice nzs j (n1s - j))) - | SVV bs iW2 nzts <- w, - let highBs = bs .&. (negate 2 !<<. k0) - v0 = svv lowBs iW2 (C.clone (C.slice nzts 0 j)) - j1 = if bs .&. b == 0 then j else j + 1 - v3 = svv highBs iW2 (C.clone (C.slice nzts j1 (n1s - j1))) - = if bs .&. b == 0 then v0 :!: v3 else - let v1 :!: v2 = go (k - k0 !<<. iW2) (C.index nzts j) - raise v = if isZero v then v else svv b iW2 (C.singleton v) - in join v0 (raise v1) :!: join (raise v2) v3 - where - k0 = k !>>. getIW2 w - b = if k0 >= 64 then 0 else b64 k0 - lowBs = w.bs .&. (b - 1) - ~j = popCount lowBs - ~n1s = popCount w.bs -{-# SPECIALIZE split :: Int -> Vector c -> Vector c :!: Vector c #-} - -join :: (C.Contiguous arr, C.Element arr c) => Op2 (VectorA arr c) -{- ^ Concatenate two vectors, e.g. undoing a 'split'. The indices in the first must be smaller - than the indices in the second. Up to \(64 d\) steps for a dense vector, or fewer if the - parts were split at a multiple of a high power of 2. -} -join = unionWith (const False) (\_ _ -> error "SV.join: Non-disjoint vectors") -{-# SPECIALIZE join :: Op2 (Vector c) #-} - --- * Addition +-- * Zip/Combine aUnionWith :: (C.Contiguous arr, C.Element arr c) => Pred c -> Op2 c -> Word64 -> arr c -> Word64 -> arr c -> Word64 :!: arr c @@ -553,7 +491,7 @@ aUnionWith _ _ bs0 nzs0 bs1 nzs1 -- to make SV.join fast aUnionWith is0 f bs0 nzs0 bs1 nzs1 = runST $ do let bsAll = bs0 .|. bs1 nzs2 <- C.new (popCount bsAll) - let go 0 bs2 _ _ j2 = (bs2 :!: ) <$> unsafeShrinkAndFreeze nzs2 j2 + let go 0 bs2 _ _ j2 = (bs2 :!:) <$> unsafeShrinkAndFreeze nzs2 j2 go bsTodo bs2 j0 j1 j2 | bs0 .&. b == 0 = do C.write nzs2 j2 $! C.index nzs1 j1 @@ -603,7 +541,7 @@ aPlusU :: (Eq c, Num c, Prim c) => Word64 -> PrimArray c -> Word64 -> P aPlusU bs0 nzs0 bs1 nzs1 = runST $ do let bsAll = bs0 .|. bs1 nzs2 <- C.new (popCount bsAll) - let go 0 bs2 _ _ j2 = (bs2 :!: ) <$> unsafeShrinkAndFreeze nzs2 j2 + let go 0 bs2 _ _ j2 = (bs2 :!:) <$> unsafeShrinkAndFreeze nzs2 j2 go bsTodo bs2 j0 j1 j2 | bs0 .&. b == 0 = do C.write nzs2 j2 $! C.index nzs1 j1 @@ -620,8 +558,6 @@ aPlusU bs0 nzs0 bs1 nzs1 = runST $ do b = lowNzBit bsTodo bsTodo' = bsTodo .^. b go bsAll bsAll 0 0 0 -{-# SPECIALIZE aPlusU :: KnownNat m => Word64 -> PrimArray (Mod m) -> - Word64 -> PrimArray (Mod m) -> Word64 :!: PrimArray (Mod m) #-} {-# INLINABLE aPlusU #-} plusU :: (Eq c, Num c, Prim c) => Op2 (VectorU c) @@ -646,16 +582,184 @@ plusU = go where ~v0 = go (C.index nzts 0) v goGT (SVE {}) _ = undefined -{-# SPECIALIZE plusU :: KnownNat m => Op2 (VectorU (Mod m)) #-} {-# INLINABLE plusU #-} +unionDisj :: (C.Contiguous arr, C.Element arr c, HasCallStack) => Op2 (VectorA arr c) +{- ^ The union of two disjoint vectors. They must have no common indices. \(O(m + n)\), or less + if the vectors have large non-overlapping subtrees. -} +unionDisj = unionWith (const False) (\_ _ -> error "unionDisj: Non-disjoint vectors") +{-# SPECIALIZE unionDisj :: Op2 (Vector c) #-} +{-# INLINEABLE unionDisj #-} + +aVApply :: (C.Contiguous dArr, C.Element dArr d, C.Contiguous cArr, C.Element cArr c) => + (Int -> d -> Op1 (S.Maybe c)) -> Int -> Int -> + Word64 -> dArr d -> Word64 -> cArr c -> Word64 :!: cArr c +aVApply f start iW2 bs0 nzs0 bs1 nzs1 = runST $ do + let bsAll = bs0 .|. bs1 + nzs2 <- C.new (popCount bsAll) + let go 0 bs2 _ _ j2 = (bs2 :!:) <$> unsafeShrinkAndFreeze nzs2 j2 + go bsTodo bs2 j0 j1 j2 + | bs0 .&. b == 0 = do + C.write nzs2 j2 $! C.index nzs1 j1 + go bsTodo' bs2 j0 (j1 + 1) (j2 + 1) + | bs1 .&. b == 0 = goD S.Nothing j1 + | otherwise = goD (S.Just (C.index nzs1 j1)) (j1 + 1) + where + b = lowNzBit bsTodo + bsTodo' = bsTodo .^. b + goD mc1 j1' = + case f !$ start + i0 !<<. iW2 !$ C.index nzs0 j0 !$ mc1 of + S.Nothing -> go bsTodo' (bs2 .^. b) (j0 + 1) j1' j2 + S.Just c -> do + C.write nzs2 j2 c + go bsTodo' bs2 (j0 + 1) j1' (j2 + 1) + where + i0 = fbTruncLog2 b + go bsAll bsAll 0 0 0 +{-# SPECIALIZE aVApply :: (Int -> d -> Op1 (S.Maybe c)) -> Int -> Int + -> Word64 -> SmallArray d -> Word64 -> SmallArray c -> Word64 :!: SmallArray c #-} + +vApply :: (C.Contiguous dArr, C.Element dArr d, C.Contiguous cArr, C.Element cArr c) => + (Int -> d -> Op1 (S.Maybe c)) -> VectorA dArr d -> Op1 (VectorA cArr c) +{- ^ @vApply f ds v@ applies @f i ds[i]@ to @v[i]@, for each non-missing element in @ds@. + Usually \(m + n\) steps, though \(m\) if the input trees have few common nodes, or up to + \(64 (d - log_{64} m) m\) if the first input is very sparse and the second input is very + dense. -} +vApply f ds0 v0 = fromMaybeSv $ go 0 ds0 (toMaybeSv v0) + where + go start ds mv + | S.Nothing <- mv = toMaybeSv $ + mapCMaybeWithIndex (\i d -> f (start + i) d S.Nothing) ds + | isZero ds = mv + | S.Just v <- mv = toMaybeSv $ combineSv (aVApply f start 0) (aVApply go start) ds v +{-# SPECIALIZE vApply :: (Int -> d -> Op1 (S.Maybe c)) -> Vector d -> Op1 (Vector c) #-} +-- @@ elim. vApply, and Maybe functions in this file ? + +sveSelect :: (C.Contiguous arr, C.Element arr c) => Word64 -> Op1 (VectorA arr c) +sveSelect bs1 v0@(SVE bs0 arr0) + | bs == bs0 = v0 + | bs == 0 = zero + | otherwise = runST $ do + a <- C.new (popCount bs) + let go 0 _ = SVE bs <$> C.unsafeFreeze a + go bsTodo j = do + C.write a j $! C.index arr0 (b64Index b bs0) + go bsTodo' (j + 1) + where + b = lowNzBit bsTodo + bsTodo' = bsTodo .^. b + go bs 0 + where + bs = bs0 .&. bs1 +sveSelect _ _ = undefined +{-# INLINE sveSelect #-} + +andNot :: (C.Contiguous arr0, C.Element arr0 c0, C.Contiguous arr1, C.Element arr1 c1) + => VectorA arr0 c0 -> VectorA arr1 c1 -> VectorA arr0 c0 +{- Like a set difference, restrict the first vector to its keys that don't occur in the second + one. \(O(m)\), but usually just 1 step for each subtree in the first vector that isn't in + the second one. -} +andNot v@(SVE {}) (SVE bs1 _) = sveSelect (complement bs1) v +andNot x y + | getIW2 x < getIW2 y = if y.bs .&. 1 /= 0 then andNot x (C.index y.nzts 0) else x + | getIW2 x > getIW2 y = if x.bs .&. 1 == 0 || isZero y then x + else andNot x (SVV 1 x.iW2 (C.singleton y)) + | x.bs .&. y.bs == 0 = x +andNot (SVV bs0 iW2 nzts0) (SVV bs1 _ nzts1) = runST $ do -- same iW2 + nzts2 <- C.new (popCount bs0) + let go 0 bs2 _ j2 = svv bs2 iW2 <$> unsafeShrinkAndFreeze nzts2 j2 + go bsTodo bs2 j0 j2 + | bs1 .&. b == 0 = do + C.write nzts2 j2 $! C.index nzts0 j0 + go bsTodo' bs2 (j0 + 1) (j2 + 1) + | otherwise = do + let t = andNot !$ C.index nzts0 j0 !$ C.index nzts1 (b64Index b bs1) + if isZero t then go bsTodo' (bs2 .^. b) (j0 + 1) j2 else do + C.write nzts2 j2 t + go bsTodo' bs2 (j0 + 1) (j2 + 1) + where + b = lowNzBit bsTodo + bsTodo' = bsTodo .^. b + go bs0 bs0 0 0 +andNot _ _ = undefined +{-# SPECIALIZE andNot :: Vector c0 -> Vector c1 -> Vector c0 #-} + +aFoldLIntersect' :: (C.Contiguous arr0, C.Element arr0 c0, + C.Contiguous arr1, C.Element arr1 c1) => + (t -> c0 -> c1 -> t) -> t -> Word64 -> arr0 c0 -> Word64 -> arr1 c1 -> t +aFoldLIntersect' f t bs0 nzs0 bs1 nzs1 = go t (bs0 .&. bs1) + where + go acc 0 = acc + go acc bsTodo = go acc' bsTodo' + where + b = lowNzBit bsTodo + jF = b64Index b + acc' = f acc !$ C.index nzs0 (jF bs0) !$ C.index nzs1 (jF bs1) + bsTodo' = bsTodo .^. b +{-# SPECIALIZE aFoldLIntersect' :: (t -> c0 -> c1 -> t) -> t -> Word64 -> SmallArray c0 -> + Word64 -> SmallArray c1 -> t #-} + +foldLIntersect' :: (C.Contiguous arr0, C.Element arr0 c0, + C.Contiguous arr1, C.Element arr1 c1) => + (t -> c0 -> c1 -> t) -> t -> VectorA arr0 c0 -> VectorA arr1 c1 -> t +{- ^ Strict left fold over the intersection (common indices) of two vectors. \(O(m + n)\), but + usually just 1 step for each common node or index. -} +foldLIntersect' f = go + where + go t (SVE bs0 nzs0) (SVE bs1 nzs1) = aFoldLIntersect' f t bs0 nzs0 bs1 nzs1 + go t x y + | getIW2 x > getIW2 y = if x.bs .&. 1 /= 0 then go t (C.index x.nzts 0) y else t + | getIW2 x < getIW2 y = if y.bs .&. 1 /= 0 then go t x (C.index y.nzts 0) else t + go t (SVV bs0 _ nzts0) (SVV bs1 _ nzts1) = aFoldLIntersect' go t bs0 nzts0 bs1 nzts1 + go _ _ _ = undefined +{-# SPECIALIZE foldLIntersect' :: (t -> c0 -> c1 -> t) -> t -> Vector c0 -> Vector c1 -> t #-} + +-- * Split/Join + +split :: (C.Contiguous arr, C.Element arr c) => Int -> VectorA arr c -> + VectorA arr c :!: VectorA arr c +{- ^ @split i v@ splits @v@ into parts with indices @\< i@, and @>= i@. Up to \(64 d\) steps + for a dense @v@, or fewer if @i@ is a multiple of a high power of 2. -} +split i = go (max i 0) + where + go k w + | w.bs == lowBs = w :!: zero + | k == 0 = zero :!: w + | SVE bs nzs <- w = SVE lowBs (C.clone (C.slice nzs 0 j)) + :!: SVE (bs .^. lowBs) (C.clone (C.slice nzs j (n1s - j))) + | SVV bs iW2 nzts <- w, + let highBs = bs .&. (negate 2 !<<. k0) + v0 = svv lowBs iW2 (C.clone (C.slice nzts 0 j)) + j1 = if bs .&. b == 0 then j else j + 1 + v3 = svv highBs iW2 (C.clone (C.slice nzts j1 (n1s - j1))) + = if bs .&. b == 0 then v0 :!: v3 else + let v1 :!: v2 = go (k - k0 !<<. iW2) (C.index nzts j) + raise v = if isZero v then v else svv b iW2 (C.singleton v) + in join v0 (raise v1) :!: join (raise v2) v3 + where + k0 = k !>>. getIW2 w + b = if k0 >= 64 then 0 else b64 k0 + lowBs = w.bs .&. (b - 1) + ~j = popCount lowBs + ~n1s = popCount w.bs +{-# SPECIALIZE split :: Int -> Vector c -> Vector c :!: Vector c #-} + +join :: (C.Contiguous arr, C.Element arr c, HasCallStack) => Op2 (VectorA arr c) +{- ^ Concatenate two vectors, e.g. undoing a 'split'. The indices in the first must be smaller + than the indices in the second. Up to \(64 d\) steps for a dense vector, or fewer if the + parts were split at a multiple of a high power of 2. -} +join = unionDisj +{-# INLINE join #-} + +-- * Addition + mkAG :: (C.Contiguous arr, C.Element arr c) => AbelianGroup c -> AbelianGroup (VectorA arr c) -- ^ Addition of vectors takes the same time as 'unionWith'. mkAG ag = AbelianGroup svFlags svEq svPlus zero isZero (mapNzFC ag.neg) where svEq (SVE bs nzs) (SVE bs' nzs') = - bs == bs' && (C.foldrZipWith (\c c' ~b -> ag.eq c c' && b) True nzs nzs') + bs == bs' && C.foldrZipWith (\c c' ~b -> ag.eq c c' && b) True nzs nzs' svEq (SVV bs iW2 nzts) (SVV bs' iW2' nzts') = bs == bs' && iW2 == iW2' && liftEq svEq nzts nzts' svEq _ _ = False @@ -664,6 +768,11 @@ mkAG ag = AbelianGroup svFlags svEq svPlus zero isZero (mapNzFC ag.neg) svFlags = agFlags (IsNontrivial ag.monFlags.nontrivial) {-# SPECIALIZE mkAG :: AbelianGroup c -> AbelianGroup (Vector c) #-} +mkAGU :: (Eq c, Num c, Prim c) => AbelianGroup (VectorU c) +-- ^ 'mkAG' with unboxed coordinates. +mkAGU = (mkAG numAG) { plus = plusU } +{-# INLINABLE mkAGU #-} + -- * Multiplication aDotWith :: (C.Contiguous arr, C.Element arr c, C.Contiguous arr1, C.Element arr1 c1) => @@ -677,8 +786,8 @@ aDotWith tAG f bs0 nzs0 bs1 nzs1 = go tAG.zero (bs0 .&. bs1) go acc bsTodo = go (tAG.plus acc t) bsTodo' where b = lowNzBit bsTodo - j bs = popCount ((b - 1) .&. bs) - t = f !$ C.index nzs0 (j bs0) !$ C.index nzs1 (j bs1) + jF = b64Index b + t = f !$ C.index nzs0 (jF bs0) !$ C.index nzs1 (jF bs1) bsTodo' = bsTodo .^. b {-# SPECIALIZE aDotWith :: AbelianGroup c2 -> (c -> c1 -> c2) -> Word64 -> SmallArray c -> Word64 -> SmallArray c1 -> c2 #-} @@ -692,8 +801,8 @@ dotWith tAG f = go where go (SVE bs0 nzs0) (SVE bs1 nzs1) = aDotWith tAG f bs0 nzs0 bs1 nzs1 go x y - | getIW2 x > getIW2 y = go (C.index x.nzts 0) y - | getIW2 x < getIW2 y = go x (C.index y.nzts 0) + | getIW2 x > getIW2 y = if x.bs .&. 1 /= 0 then go (C.index x.nzts 0) y else tAG.zero + | getIW2 x < getIW2 y = if y.bs .&. 1 /= 0 then go x (C.index y.nzts 0) else tAG.zero go (SVV bs0 _ nzts0) (SVV bs1 _ nzts1) = aDotWith tAG go bs0 nzts0 bs1 nzts1 go _ _ = undefined {-# SPECIALIZE dotWith :: AbelianGroup c2 -> (c -> c1 -> c2) -> Vector c -> Vector c1 -> c2 #-} @@ -725,6 +834,192 @@ monicizeUnit cR@(Ring { times }) i v = in if rIsOne cR c then v else mapNzFC (`times` rInv cR c) v {-# SPECIALIZE monicizeUnit :: Ring c -> Int -> Op1 (Vector c) #-} +aMkModR :: (C.Contiguous arr, C.Element arr c) => + Ring c -> AbelianGroup (VectorA arr c) -> (c -> Op1 (VectorA arr c)) -> + ModR c (VectorA arr c) +aMkModR cR vAG vTimesNzdC = Module vAG scale vBDiv + where + cIsZero = cR.isZero + cNzds = cR.rFlags.noZeroDivisors + timesNzC = if cNzds then vTimesNzdC else timesC cR + scale c v = if cIsZero c then zero else timesNzC c v + minusTimesNzC v q w = vAG.plus v $! timesNzC (cR.neg q) w + vBDiv doFull v w = fromMaybe (cR.zero, v) $ do + i :!: wc <- toLazy $ headPairMaybe w + vc <- if doFull.b then toLazy $ indexMaybe v i else + do vi :!: c <- toLazy $ headPairMaybe v; pureIf (vi == i) c + let (q, _) = cR.bDiv doFull vc wc + pureIf (not (cIsZero q)) (q, minusTimesNzC v q w) +{-# SPECIALIZE aMkModR :: Ring c -> AbelianGroup (Vector c) -> (c -> Op1 (Vector c)) -> + ModR c (Vector c) #-} + +mkModR :: Ring c -> ModR c (Vector c) +-- ^ Make a right module of vectors over a coordinate ring. +mkModR cR = aMkModR cR (mkAG cR.ag) (timesNzdC cR) + +mkModRU :: (Eq c, Num c, Prim c) => Ring c -> ModR c (VectorU c) +-- ^ 'mkModR' with unboxed coordinates. +mkModRU cR = aMkModR cR mkAGU timesNzdCU +{-# INLINABLE mkModRU #-} + +-- * Permute + +{- | A t'Permute' is a permutation, i.e. a bijection from @[0 .. r - 1]@ to @[0 .. r - 1]@ for + some (non-unique) @r@. It is stored sparsely, with fixpoints omitted. -} +data Permute = Permute { to :: Vector Int, from :: Vector Int } + deriving Show; -- ^ e.g. for testing & debugging + +instance Eq Permute where + p == q = p.to == q.to + +pToF :: Permute -> Op1 Int +-- ^ Use a permutation as a function. \(d\) steps per function call. +pToF p i = index i p.to i + +pIdent :: Permute +-- ^ The identity function as a permutation. +pIdent = Permute zero zero + +pSwap :: Int -> Int -> Permute +-- ^ Transpose (swap) two @Int@s. +pSwap i j = if i == j then pIdent else Permute v v + where + ~v = fromDistinctNzs [i :!: j, j :!: i] + +pCycle :: [Int] -> Permute +{- ^ Take each element of the list to the next one, and the last element (if the list is + nonempty) to the first. The elements must be distinct. Other @Int@s are left fixed. + \(O(n log_2 n)\). -} +pCycle [] = pIdent +pCycle [_] = pIdent +pCycle ks@(h : t) = Permute (fromDistinctNzs (S.zip ks rotL)) + (fromDistinctNzs (S.zip rotL ks)) + where + rotL = t ++ [h] + +fToP :: Int -> Op1 Int -> Permute +{- ^ Convert @r@ and a bijection on @[0, 1 .. r - 1]@ to a permutation. Fixpoints (@i@ where + @f i == i@) are compressed (stored compactly). \(O(r log_2 r)\). -} +fToP r f = Permute (fromDistinctAscNzs toL) (fromDistinctNzs (map S.swap toL)) + where + toL = [i :!: e | i <- [0, 1 .. r - 1], let e = f i, e /= i] + +injToPermute :: Vector Int -> Permute +{- ^ Extend an injective (finite) partial function on [0 ..] to a minimal permutation. The newly + defined part of the function will be monotonic. -} +injToPermute to0 = Permute to from + where + keysAndNot = keys .* andNot + to1 = mapCMaybeWithIndex (\i c -> if i /= c then S.Just c else S.Nothing) to0 + from1 = (fromDistinctNzs . map S.swap . toDistinctAscNzs) to1 + is = keysAndNot to1 from1 + js = keysAndNot from1 to1 + to = unionDisj to1 (fromDistinctAscNzs (S.zip js is)) + from = unionDisj from1 (fromDistinctAscNzs (S.zip is js)) + +rComposePTo :: Vector Int -> Vector Int -> Vector Int -> Vector Int +-- Compute (.to) of the right composition of (Permute pTo pFrom) and qTo. +rComposePTo pTo pFrom qTo = unionWith (== -1) const pTo' qTo + where + ic i c = i :!: (if c == i then -1 else c) + movesTwice = fromDistinctNzs (foldLIntersect' (\r i c -> ic i c : r) [] pFrom qTo) + pTo' = unionWith (const False) const movesTwice pTo + +pCompose :: Op2 Permute +-- ^ (left) composition of permutations as functions +pCompose (Permute pTo pFrom) (Permute qTo qFrom) = Permute to from + where + to = rComposePTo qTo qFrom pTo + from = rComposePTo pFrom pTo qFrom + +pGroup :: Group Permute +-- ^ The (infinite) group of permutations under (left) composition. +pGroup = MkMonoid { .. } + where + monFlags = MonoidFlags { nontrivial = True, abelian = False, isGroup = True } + eq = (==) + op = pCompose + ident = pIdent + isIdent p = isZero p.to + inv p = Permute p.from p.to + +permuteV2 :: (C.Contiguous arr, C.Element arr c) => + Vector Int -> VectorA arr c -> [Int :!: c] -> VectorA arr c :!: [Int :!: c] +-- Split @v@ into fixpoints of @to@ and pairs pushed onto @ics@. +permuteV2 (SVE toBs toIs) v@(SVE vBs nzs) ics0 = + if fixBs == vBs then v :!: ics0 else runST $ do + fixNzsMut <- C.new (popCount fixBs) + let go 0 _ _ ics = do + fixNzs <- C.unsafeFreeze fixNzsMut + pure (SVE fixBs fixNzs :!: ics) + go bsTodo vJ fixJ ics + | toBs .&. b == 0 = do + C.write fixNzsMut fixJ c + go bsTodo' (vJ + 1) (fixJ + 1) ics + | let toJ = b64Index b toBs + = go bsTodo' (vJ + 1) fixJ ((C.index toIs toJ :!: c) : ics) + where + b = lowNzBit bsTodo + bsTodo' = bsTodo .^. b + c = C.index nzs vJ + go vBs 0 0 ics0 + where + fixBs = vBs .&. complement toBs +permuteV2 (SVV bs iW2 nzts) v ics | iW2 > getIW2 v = + if bs .&. 1 == 0 then v :!: ics else permuteV2 (C.index nzts 0) v ics +permuteV2 to v@(SVV _ iW2 _) ics | getIW2 to < iW2 = + permuteV2 (SVV 1 iW2 (C.singleton to)) v ics +permuteV2 (SVV toBs iW2 toNzts) v@(SVV vBs _ vNzts) ics0 = + if fixBs0 == vBs then v :!: ics0 else runST $ do + fixNztsMut <- C.new (popCount vBs) + let go 0 fixBs _ fixJ ics = do + fixNzts <- unsafeShrinkAndFreeze fixNztsMut fixJ + pure (svv fixBs iW2 fixNzts :!: ics) + go bsTodo fixBs vJ fixJ ics + | toBs .&. b == 0 = do + C.write fixNztsMut fixJ vNzt + go bsTodo' fixBs vJ' (fixJ + 1) ics + | let toJ = b64Index b toBs + v1 :!: ics' = permuteV2 (C.index toNzts toJ) vNzt ics + = if isZero v1 then go bsTodo' fixBs vJ' fixJ ics' else do + C.write fixNztsMut fixJ v1 + go bsTodo' (fixBs .|. b) vJ' (fixJ + 1) ics' + where + b = lowNzBit bsTodo + bsTodo' = bsTodo .^. b + vNzt = C.index vNzts vJ + vJ' = vJ + 1 + go vBs fixBs0 0 0 ics0 + where + fixBs0 = vBs .&. complement toBs +permuteV2 _ _ _ = undefined +{-# SPECIALIZE permuteV2 :: Vector Int -> Vector c -> [Int :!: c] -> Vector c :!: [Int :!: c] + #-} +{-# INLINEABLE permuteV2 #-} + +permuteV :: (C.Contiguous arr, C.Element arr c) => Permute -> Op1 (VectorA arr c) +{- Permute the coordinates of a vector. For a free module @R^n@, this is the linear map where + the permutation acts on the standard basis elements. -} +permuteV p v = unionDisj v1 (fromDistinctNzs ics) + where + v1 :!: ics = permuteV2 p.to v [] +{-# SPECIALIZE permuteV :: Permute -> Op1 (Vector c) #-} +{-# INLINEABLE permuteV #-} + +swap :: (C.Contiguous arr, C.Element arr c) => Int -> Int -> Op1 (VectorA arr c) +-- ^ swap two coordinates. Up to \(128 d\) steps for a very dense vector. +swap i j = permuteV (pSwap i j) +{-# SPECIALIZE swap :: Int -> Int -> Op1 (Vector c) #-} + +sortPermute :: (C.Contiguous arr, C.Element arr c) => + Cmp c -> VectorA arr c -> (VectorA arr c, Permute) +{- ^ Sort a vector, and return a permutation on its coordinates to get the sorted result. + \(O(n log n)\). -} +sortPermute cmp v = (fromNzs cs, pGroup.inv (injToPermute (fromNzs is))) + where + (is, cs) = S.unzip (sortLBy (cmp `on` S.snd) (toDistinctAscNzs v)) +{-# SPECIALIZE sortPermute :: Cmp c -> Vector c -> (Vector c, Permute) #-} + -- * I/O showPrec :: (C.Contiguous arr, C.Element arr c) => diff --git a/test/Math/Algebra/General/TestAlgebra.hs b/test/Math/Algebra/General/TestAlgebra.hs index c4665cc..a2e5f40 100644 --- a/test/Math/Algebra/General/TestAlgebra.hs +++ b/test/Math/Algebra/General/TestAlgebra.hs @@ -62,7 +62,7 @@ module Math.Algebra.General.TestAlgebra ( pairTestOps, -- * List - listTestOps, listTestEq, allTM, isSortedBy, slTestOps, + listTestOps, listTestEq, allTM, slTestOps, -- * Parse parseTest, @@ -76,6 +76,7 @@ module Math.Algebra.General.TestAlgebra ( import Math.Algebra.General.Algebra hiding (assert) +-- import Debug.Trace.Text (traceEvent) import Hedgehog (Gen, Property, PropertyT, Range, (===), annotate, assert, cover, diff, discard, failure, forAllWith, property, withDiscards, withTests) @@ -87,6 +88,7 @@ import Test.Tasty.Hedgehog (testProperty) import GHC.Records import Control.Monad (join, unless, when) +-- import Control.Monad.IO.Class (liftIO) import Data.Functor.Classes (liftEq, liftEq2) import Data.Maybe (fromMaybe) import Data.Strict.Tuple ((:!:), pattern (:!:)) @@ -555,16 +557,6 @@ allTM aSP p as = do annotateB $ listF qs failure --- | The 'isSortedBy' function returns 'True' iff the predicate returns true --- for all adjacent pairs of elements in the list. -isSortedBy :: (a -> a -> Bool) -> [a] -> Bool --- from Data.List.Ordered -isSortedBy lte = loop - where - loop [] = True - loop [_] = True - loop (x:y:zs) = (x `lte` y) && loop (y:zs) - slTestOps :: Range Int -> TestOps a -> TestOps (SL.List a) -- ^ t'TestOps' for a strict list, given a 'Hedgehog.Range.Range' for the length diff --git a/test/Math/Algebra/Linear/TestSparseVector.hs b/test/Math/Algebra/Linear/TestSparseVector.hs index 37a9bac..29197d9 100644 --- a/test/Math/Algebra/Linear/TestSparseVector.hs +++ b/test/Math/Algebra/Linear/TestSparseVector.hs @@ -18,12 +18,13 @@ import qualified Hedgehog.Range as Range import Control.Monad (when) import Data.Bifunctor (second) import qualified Data.IntMap.Strict as IM +import Data.Maybe (fromMaybe) +import qualified Data.Primitive.Contiguous as C import Data.Strict.Classes (toLazy, toStrict) import qualified Data.Strict.Maybe as S import qualified Data.Strict.Tuple as S import Data.Strict.Tuple ((:!:), pattern (:!:)) import qualified Data.Text as T -import Safe (headMay, lastMay) -- @@@ or use IntMap lookupMin, lookupMax and toStrict, remove 'safe' dependency sJustIf :: Bool -> a -> S.Maybe a @@ -35,8 +36,9 @@ nToText26 :: Integral n => n -> Text nToText26 n = T.singleton (toEnum (fromEnum 'a' + fromIntegral n `mod` 26)) -testOpsAG :: AbelianGroup c -> Range Int -> TestOps Int -> TestOps c -> - TestOps (SV.Vector c) +testOpsAG :: (C.Contiguous arr, C.Element arr c, Show (SV.VectorA arr c)) => + AbelianGroup c -> + Range Int -> TestOps Int -> TestOps c -> TestOps (SV.VectorA arr c) -- ^ @testOpsAG cAG sumRange iTA cTA@. The caller tests @cAG@. testOpsAG cAG sumRange iTA cTA = TestOps tSP tCheck gen vAG.eq where @@ -45,49 +47,56 @@ testOpsAG cAG sumRange iTA cTA = TestOps tSP tCheck gen vAG.eq SV.iFoldL (\_i _ c -> cTA.tCheck notes1 c) (pure ()) v -- show i on failure? tCheckBool (errs ++ notes1) (null errs) where - notes1 = (tSP v).t : notes + notes1 = {- (tSP v).t -} showT v : notes errs = SV.check (const cAG.isZero) v vAG = SV.mkAG cAG iCToV = SV.fromPIC cAG.isZero gen = sumL' vAG <$> Gen.list sumRange (liftA2 iCToV iTA.gen cTA.gen) -type V = SV.Vector Integer -- the main type for testing SparseVector.hs -type Y = Int -- V maps almost injectively to Y -type VL = [Int :!: Integer] -- only DistinctAscNzs; V->VL->V == id, so VL->V is a - -- surjection --- type IM = IM.IntMap -- only nonzero terms; V->IM->V == id +type C = Int +type V = SV.Vector C -- the main type for testing SparseVector.hs +type H = Int -- V maps almost injectively to H +type VL = [Int :!: C] -- only DistinctAscNzs; V->VL->V == id, so VL->V is a surjection +type IM = IM.IntMap C -- only nonzero terms; V->IM->V == id +type VU = SV.VectorU C -- isomorphic to V as right modules over C tests :: TestTree -- ^ Test the "Math.Algebra.Linear.SparseVector" module. tests = testGroup "SparseVector" testsL where - cAG = zzAG + cAG = intRing.ag largeInts = take 6 $ iterate (`quot` 2) (maxBound :: Int) - iTA = numVarTestOps "u" (Gen.frequency + iTA = numVarTestOps "u" (Gen.frequency -- index test ops [(10, Gen.int (Range.exponential 0 1_000_000)), (1, Gen.element largeInts)]) - cTA = zzTestOps { gen = zzExpGen 200 } - vTA = testOpsAG cAG (Range.linear 0 20) iTA cTA - vAG = SV.mkAG cAG - vToY :: V -> Y - iCToY i c = (3 * i `rem` 101 + 5) * fromIntegral c - vToY = SV.foldBIMap' (+) 0 iCToY + cTA = testOps0 (Gen.int (Range.exponentialFrom 0 minBound maxBound)) + vTA = testOpsAG cAG (Range.linear 0 20) iTA cTA :: TestOps V + vAG = SV.mkAG cAG :: AbelianGroup V + vToH :: V -> H -- a homomorphism of right modules over C + iCToH i c = (3 * i `rem` 101 + 5) * c + vToH = SV.foldBIMap' (+) 0 iCToH vToIM = IM.fromDistinctAscList . map toLazy . SV.toDistinctAscNzs imNzsToV = SV.fromDistinctAscNzs . map toStrict . IM.toAscList + vToVU = SV.mapNzFC id :: V -> VU + vuToV = SV.mapNzFC id :: VU -> V - testViaY :: V -> Y -> TestM () -- tAnnotate v, and check it maps to y - testViaY = tImageEq vTA (===) vToY + testViaH :: V -> H -> TestM () -- tAnnotate v, and check it maps to h + testViaH = tImageEq vTA (===) vToH testViaL :: TestRel b -> (V -> b) -> (VL -> b) -> TestM () -- test the (V -> b) testViaL bTestEq f okF = sameFun1TR vTA bTestEq f (okF . SV.toDistinctAscNzs) testEqToVL :: V -> VL -> TestM () testEqToVL w okDANzs = vTA.tEq w (SV.fromDistinctAscNzs okDANzs) + testViaIM :: TestRel b -> (V -> b) -> (IM -> b) -> TestM () -- test the (V -> b) + testViaIM bTestEq f okF = sameFun1TR vTA bTestEq f (okF . vToIM) + + sJustNz c = sJustIf (c /= 0) c fromICTest = singleTest "fromIC" $ do -- test fromPIC, fromIMaybeC, fromNzIC i <- genVis iTA c <- genVis cTA - when (c /= 0) $ testViaY (SV.fromNzIC i c) (iCToY i c) - testViaY (SV.fromPIC cAG.isZero i c) (iCToY i c) - testViaY (SV.fromIMaybeC i (sJustIf (c /= 0) c)) (iCToY i c) + when (c /= 0) $ testViaH (SV.fromNzIC i c) (iCToH i c) + testViaH (SV.fromPIC cAG.isZero i c) (iCToH i c) + testViaH (SV.fromIMaybeC i (sJustNz c)) (iCToH i c) distinctAscNzsTest = singleTest "distinctAscNzs" $ do -- test toDistinctAscNzs, fromDistinctAscNzs v <- genVis vTA @@ -97,12 +106,13 @@ tests = testGroup "SparseVector" testsL tCheckBool ["Not strictly ascending"] $ isSortedBy ((<) `on` S.fst) vL vTA.tEq (sumL' vAG (map (S.uncurry SV.fromNzIC) vL)) v vTA.tEq (SV.fromDistinctAscNzs vL) v + -- TODO: test fromDistinctNzs indexTest = singleTest "index, indexMaybe, split, join" $ do v <- genVis vTA i <- genVis (testOps0 (Gen.int (Range.exponential 0 10_000))) let vL = SV.toDistinctAscNzs v mc = lookup i (map toLazy vL) - SV.index 0 v i === maybe 0 id mc + SV.index 0 v i === fromMaybe 0 mc SV.indexMaybe v i === toStrict mc let (v0 :!: v1) = SV.split i v testEqToVL v0 (filter ((< i) . S.fst) vL) @@ -115,8 +125,8 @@ tests = testGroup "SparseVector" testsL SV.lastPair v === S.fromJust (SV.lastPairMaybe v) foldsTest = testGroup "folds" [iFoldRTest, iFoldLTest, foldBIMap'Test] where - iNNToN :: Int -> Op2 Integer - iNNToN i m n = 2 * fromIntegral i + m - n + iNNToN :: Int -> Op2 C + iNNToN i m n = 2 * i + m - n iFoldRTest = singleTest "iFoldR" $ testViaL (===) (SV.iFoldR iNNToN 100) (foldr (S.uncurry iNNToN) 100) iFoldLTest = singleTest "iFoldL" $ @@ -130,9 +140,7 @@ tests = testGroup "SparseVector" testsL testEqToVL (SV.mapC (== 0) (`rem` 3) v) (filter ((/= 0) . S.snd) (map (second (`rem` 3)) vL)) testEqToVL (SV.mapNzFC (3 *) v) (map (second (3 *)) vL) - let iCToMC i c = sJustIf (c' /= 0) c' - where - c' = (fromIntegral i + c) `rem` 3 + let iCToMC i c = sJustNz ((i + c) `rem` 3) pToMP (i :!: c) = (i :!:) <$> iCToMC i c testEqToVL (SV.mapCMaybeWithIndex iCToMC v) (S.catMaybes (map pToMP vL)) unionTest = singleTest "unionWith" $ do @@ -143,21 +151,66 @@ tests = testGroup "SparseVector" testsL wIM = vToIM w vTA.tEq (SV.unionWith (== 0) f v w) (imNzsToV (IM.filter (/= 0) (IM.unionWith f vIM wIM))) - -- @@@ incl. tCheck/tAnnotate intermediate values from here on: - -- vApply, invPermute, swap; 4 mult; showPrec? - -- :@@@ + plusUTest = singleTest "plusU" $ + sameFunAABTR vTA vTA.tEq (vuToV .* SV.plusU `on` vToVU) vAG.plus + -- TODO: test unionDisj + vApplyTest = singleTest "vApply" $ sameFunAABTR vTA vTA.tEq (SV.vApply iDMcToMc) okF + where + iDMcToC i d mc = (i + 2 * d + 3 * S.fromMaybe 5 mc) `rem` 7 + iDMcToMc i d = sJustNz . iDMcToC i d + iDCToMc i d = toLazy . iDMcToMc i d . S.Just + dIMToCIM = IM.filter (/= 0) . IM.mapWithKey (\i d -> iDMcToC i d S.Nothing) + mergeIM = IM.mergeWithKey iDCToMc dIMToCIM id + okF = imNzsToV .* mergeIM `on` vToIM + -- TODO: test foldLIntersect' + -- TODO: test mkAGU + dotWithTest = singleTest "dotWith" $ sameFunAABTR vTA (===) (SV.dotWith cAG mult) okF + where + mult c d = c * (d `quot` 2) -- noncommutative + okF v = SV.foldBIMap' (+) 0 (\i c -> SV.index 0 v i `mult` c) + timesCsTest = singleTest "timesNzdC, timesNzdCU, timesC" $ do + v <- genVis vTA + c <- genVis cTA + let vcH = vToH v * c + when (odd c) $ do + testViaH (SV.timesNzdC intRing c v) vcH + testViaH (vuToV (SV.timesNzdCU c (vToVU v))) vcH + testViaH (SV.timesC intRing c v) vcH + -- TODO: test over a noncommutative ring also, using a right-linear map for vToH + -- TODO: test monicizeUnit, e.g. using Z mod p or Q + -- TODO: test mkModR, mkModRU + {- TODO: test Permute section + vPComposeTest = singleTest "vPCompose" $ do + v <- genVis vTA + pv <- SV.mapNzFC abs1 <$> genVis vTA + vTA.tEq (SV.vPCompose v (SV.OrId pv)) ((imNzsToV .* vPComposeIM `on` vToIM) v pv) + where + abs1 n = if n == minBound then 1 else abs n + vPComposeIM vm pm = IM.compose vm (IM.union pm (IM.mapWithKey (\i _ -> i) vm)) -} + swapTest = singleTest "swap" $ do + v <- genVis vTA + i <- genVis iTA + j <- genVis iTA + let swapIJ k = if k == i then j else if k == j then i else k + vTA.tEq (SV.swap i j v) (imNzsToV (IM.mapKeys swapIJ (vToIM v))) + -- TODO: test showPrec testsL = [ - singleTest "not eq, vToY" $ almostInjectiveTM vTA (==) vToY, - -- testViaY is now valid - singleTest "plus" $ homomTM vTA vAG.plus (===) (+) vToY, + singleTest "not eq, vToH" $ almostInjectiveTM vTA (==) vToH, + -- testViaH is now valid + singleTest "plus" $ homomTM vTA vAG.plus (===) (+) vToH, abelianGroupTests vTA (IsNontrivial True) vAG, -- vAG has now been tested (by the above lines), so vTA.gen and vTA.eq have been also testOnce "zero" $ vTA.tEq SV.zero vAG.zero, fromICTest, distinctAscNzsTest, - -- testViaL, testEqToVL, vToIM, and imToV are now valid + -- testViaL, testEqToVL, vToIM, imToV, vToIM, imNzsToV, and testViaIM are now valid singleTest "isZero" $ testViaL (===) SV.isZero null, singleTest "size" $ testViaL (===) SV.size length, - singleTest "headPairMaybe" $ testViaL (===) SV.headPairMaybe (toStrict . headMay), - singleTest "lastPairMaybe" $ testViaL (===) SV.lastPairMaybe (toStrict . lastMay), - indexTest, headLastTest, foldsTest, mapsTest, unionTest + singleTest "headPairMaybe" $ + testViaIM (===) SV.headPairMaybe (toStrict . (toStrict <$>) . IM.lookupMin), + singleTest "lastPairMaybe" $ + testViaIM (===) SV.lastPairMaybe (toStrict . (toStrict <$>) . IM.lookupMax), + indexTest, headLastTest, foldsTest, mapsTest, + -- vToVU and vuToV are now valid + unionTest, plusUTest, vApplyTest, dotWithTest, timesCsTest, + swapTest ] diff --git a/timings/bench.hs b/timings/bench.hs index 3c4685f..8055aeb 100644 --- a/timings/bench.hs +++ b/timings/bench.hs @@ -88,13 +88,6 @@ benchesStrictList = benchWhnf (force . map (+ 1)) (showSize "force map") numsLL] -intRing :: Ring Int -intRing = numRing rFlags (const intDiv) - where - rFlags = RingFlags { commutative = True, noZeroDivisors = False, nzInverses = False } - intDiv y 0 = (0, y) - intDiv y m = quotRem y m - divDeep' :: Ring r -> (r -> r -> S.Pair r r) divDeep' rR = toStrict .* rR.bDiv (IsDeep True) diff --git a/timings/time-gb.hs b/timings/time-gb.hs index 59ab0fa..d9c7818 100644 --- a/timings/time-gb.hs +++ b/timings/time-gb.hs @@ -29,8 +29,8 @@ import Data.List (unfoldr) -} main :: IO () main = do -{- let rands = take 300000 $ unfoldr (Just . uniform) (mkStdGen 137) :: [Int] - print $ sum $ sortByPar 100 compare rands -} +{- let rands = take 300_000 $ unfoldr (Just . uniform) (mkStdGen 137) :: [Int] + print $ sum $ sortByPar 100_000 compare rands -} nCores <- getNumCapabilities args <- getArgs