Skip to content

Commit

Permalink
Add SV.Permute, a semi-sparse permutation
Browse files Browse the repository at this point in the history
Also more work on VectorA & MatrixA, and use samsort.
  • Loading branch information
DaveBarton committed Aug 19, 2024
1 parent 6d63cd6 commit fe8063c
Show file tree
Hide file tree
Showing 13 changed files with 665 additions and 287 deletions.
11 changes: 6 additions & 5 deletions calculi.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ common deps
build-depends:
base >= 4.17 && < 5.0,
containers >= 0.6 && < 0.8,
contiguous >= 0.6.1.1 && < 0.7,
extra >= 1.6 && < 1.8,
fmt >= 0.4 && < 0.7,
ghc-trace-events < 0.2,
megaparsec < 9.7,
mod >= 0.1.2 && < 0.3,
parser-combinators < 1.4,
safe < 0.4,
-- safe < 0.4,
strict >= 0.4 && < 0.6,
text < 2.2

Expand All @@ -66,14 +68,13 @@ library

build-depends:
async < 2.3,
contiguous >= 0.6.1.1 && < 0.7,
ghc-trace-events < 0.2,
-- hashable, unordered-containers,
-- inspection-testing >= 0.3 && < 0.6,
-- poly >= 0.5.1, deepseq, finite-typelits, vector-sized,
primitive >= 0.8 && < 0.10,
random >= 1.2.1 && < 1.3,
rrb-vector >= 0.2.1 && < 0.3,
samsort < 0.2,
stm >= 2.4 && < 2.6,
strict-list < 0.2,
time >= 1.9.3 && < 2,
Expand Down Expand Up @@ -111,7 +112,7 @@ test-suite calculi-test
Math.Algebra.Commutative.TestEPoly
Math.Algebra.Commutative.TestBinPoly

build-depends: calculi, hedgehog == 1.*, mod, rrb-vector, strict-list, tasty < 1.6,
build-depends: calculi, hedgehog == 1.*, rrb-vector, strict-list, tasty < 1.6,
tasty-hedgehog < 1.5

hs-source-dirs: test
Expand All @@ -135,7 +136,7 @@ benchmark bench
import: deps
main-is: bench.hs
type: exitcode-stdio-1.0
build-depends: calculi, deepseq < 2, random >= 1.2 && < 1.3, tasty, tasty-bench < 0.4
build-depends: calculi, deepseq < 2, random >= 1.2 && < 1.3, tasty, tasty-bench < 0.5
-- , poly, vector, vector-sized
hs-source-dirs: timings
ghc-options: -rtsopts -with-rtsopts=-T
31 changes: 26 additions & 5 deletions src/Control/Parallel/Cooperative.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
If you can guarantee your threads never block during a task, you can allocate them
one-per-core (really one-per-capability) using 'parNonBlocking', for a small performance
speedup. Otherwise, it's better to have some extra threads, and let them migrate between
cores, using 'parThreads'.
cores, using 'parThreads'. In either case, we don't create and destroy threads as often as
sparks do, which seems to speed things up.
This module uses "Control.Concurrent.Async" to ensure that uncaught exceptions are
propagated to parent threads, and orphaned threads are always killed.
Expand All @@ -44,7 +45,7 @@ module Control.Parallel.Cooperative (
vecMapParChunk, rrbMapParChunk,

-- * Folding and sorting
foldBalanced, foldBalancedPar, sortByPar,
foldBalanced, foldBalancedPar, sortLBy, -- sortByPar,

-- * Utilities
seqSpine, seqElts, inc1TVar, maybeStateTVar, popTVar,
Expand All @@ -54,14 +55,18 @@ module Control.Parallel.Cooperative (

import Control.Monad (when)
import Control.Monad.Extra (ifM, unlessM, whileM)
import Control.Monad.ST (runST)
import Data.Bits (Bits, FiniteBits, (.&.), bit, countLeadingZeros, finiteBitSize, xor)
import Data.Foldable (toList)
import qualified Data.IntMap.Strict as IMS
import Data.List (sortBy, uncons, unfoldr)
import Data.List.Extra (chunksOf, mergeBy)
import Data.List (uncons, unfoldr)
import Data.List.Extra (chunksOf)
-- import Data.Maybe (isJust)
import Data.Maybe (isNothing)
import Data.Primitive.Array (pattern MutableArray)
import qualified Data.Primitive.Contiguous as C
import qualified Data.RRBVector as RRB
import Data.SamSort (sortArrayBy)
import qualified Data.Sequence as Seq
import qualified Data.Vector as V
import GHC.Stack (HasCallStack)
Expand Down Expand Up @@ -315,9 +320,25 @@ foldBalancedPar f as = if null (drop 3 as) then foldBalanced f as else
buddy k = (2 * lowNzBit k) `xor` k
top = bit (fbTruncLog2 n)

sortLBy :: (a -> a -> Ordering) -> [a] -> [a]
{- ^ A faster replacement for 'Data.List.sortBy' that is not lazy: the entire list is always
    sorted. The work is currently delegated to
    [samsort](https://hackage.haskell.org/package/samsort), so the sort is stable and
    adaptive. Like 'Data.SamSort.sortArrayBy', 'sortLBy' is inlined to get the best
    performance out of statically known comparison functions. To avoid code duplication,
    create a wrapping definition and reuse it as necessary. -}
sortLBy cmp xs = runST $ do
    marr <- C.fromListMutable xs    -- copy the list into a fresh mutable array
    len <- C.sizeMut marr
    case marr of                    -- unwrap the array to pass it to 'sortArrayBy'
        MutableArray raw -> sortArrayBy cmp raw 0 len
    C.toListMutable marr
{-# INLINE sortLBy #-}

{- @@@ faster to just use sortLBy (samsort)
sortByPar :: Int -> (a -> a -> Ordering) -> [a] -> [a]
{- ^ Strict stable sort by sorting chunks in parallel. The chunk size must be positive, and 100
appears to be a good value. The spine of the result is forced. -}
sortByPar chunkSize cmp as = if null as then [] else
forkJoinPar (chunksOf chunkSize) (foldBalancedPar (\es -> seqSpine . mergeBy cmp es))
(seqSpine . sortBy cmp) as
(seqSpine . sortLBy cmp) as
-- @@@ INLINE sortByPar also, doc
-}
3 changes: 1 addition & 2 deletions src/Math/Algebra/Commutative/BinPoly.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import Control.Monad.Extra (pureIf)
import Data.Bits ((.&.), (.|.), bit, complement, countLeadingZeros, popCount, testBit,
unsafeShiftL, unsafeShiftR)
import Data.Foldable (toList)
import Data.List (sortBy)
import Data.Maybe (catMaybes, fromJust)
import Data.Word (Word64)
import StrictList2 (pattern (:!))
Expand Down Expand Up @@ -115,7 +114,7 @@ data BPOtherOps ev vals = BPOtherOps {
}

bpSortCancel :: Cmp ev -> SL.List ev -> BinPoly ev
bpSortCancel evCmp evs = cancelRev (sortBy evCmp (SL.toListReversed evs)) SL.Nil
bpSortCancel evCmp evs = cancelRev (sortLBy evCmp (SL.toListReversed evs)) SL.Nil
where
cancelRev (v : t1@(w : ~t2)) r
| evCmp v w == EQ = cancelRev t2 r
Expand Down
10 changes: 3 additions & 7 deletions src/Math/Algebra/Commutative/Field/ZModPW.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ zzModPW = (field numAG (*) 1 fromInteger recip, balRep)
in toInteger (fromIntegral (if u > maxBalRep then u - p else u) :: Int)


-- NOTE(review): 'unsafeCoerce' is only sound here if @Mod m@ has the same runtime
-- representation as 'Word' (e.g. it is a newtype over 'Word') -- confirm against the
-- @Mod@ type this module imports.
toWord :: Mod m -> Word
-- ^ The result is @\< m@. @@ move to Data.Mod.Word
toWord = unsafeCoerce

-- NOTE(review): 'unsafeCoerce' is only sound here if @Mod m@ has the same runtime
-- representation as 'Word' (e.g. it is a newtype over 'Word') -- confirm against the
-- @Mod@ type this module imports. No range check or reduction is performed, so the
-- caller is responsible for the precondition below.
unsafeFromWord :: Word -> Mod m
-- ^ The argument must be @\< m@. @@ move to Data.Mod.Word
unsafeFromWord = unsafeCoerce
Expand All @@ -42,7 +38,7 @@ newtype ModWord32 (m :: Nat) = ModW32 { w32 :: Word32 {- ^ @w32 \< m@ -} }
deriving newtype (Eq, Show, Prim)

modWToW32 :: Mod m -> ModWord32 m
modWToW32 = ModW32 . fromIntegral . toWord
modWToW32 = ModW32 . fromIntegral . unMod

modW32ToW :: ModWord32 m -> Mod m
modW32ToW = unsafeFromWord . fromIntegral . (.w32)
Expand All @@ -69,7 +65,7 @@ newtype ModWord16 (m :: Nat) = ModW16 { w16 :: Word16 {- ^ @w16 \< m@ -} }
deriving newtype (Eq, Show, Prim)

modWToW16 :: Mod m -> ModWord16 m
modWToW16 = ModW16 . fromIntegral . toWord
modWToW16 = ModW16 . fromIntegral . unMod

modW16ToW :: ModWord16 m -> Mod m
modW16ToW = unsafeFromWord . fromIntegral . (.w16)
Expand All @@ -96,7 +92,7 @@ newtype ModWord8 (m :: Nat) = ModW8 { w8 :: Word8 {- ^ @w8 \< m@ -} }
deriving newtype (Eq, Show, Prim)

modWToW8 :: Mod m -> ModWord8 m
modWToW8 = ModW8 . fromIntegral . toWord
modWToW8 = ModW8 . fromIntegral . unMod

modW8ToW :: ModWord8 m -> Mod m
modW8ToW = unsafeFromWord . fromIntegral . (.w8)
Expand Down
6 changes: 3 additions & 3 deletions src/Math/Algebra/Commutative/GroebnerBasis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Control.Monad.Extra (ifM, orM, whenM)
import Data.Bits ((.&.), (.|.))
import Data.Foldable (find, minimumBy, toList)
import Data.Int (Int64)
import Data.List (elemIndex, findIndices, groupBy, sortBy)
import Data.List (elemIndex, findIndices, groupBy)
import Data.List.Extra (chunksOf)
import Data.Maybe (catMaybes, fromJust, isJust, isNothing, listToMaybe, mapMaybe)
import qualified Data.RRBVector as GBV
Expand Down Expand Up @@ -250,7 +250,7 @@ updatePairs (GBPolyOps { nVars, evCmp, extraSPairs, useSugar }) gMGis ijcs tGi
itcs = catMaybes (zipWith (\i -> fmap (giToSp i t)) [0..] itMGis) :: [SPair ev]
-- "sloppy" sugar method:
itcss = TS.measurePure "1.1.2 sort/group new itcs" $ seqElts $
groupBy (cmpEq lcmCmp) (sortBy lcmCmp itcs)
groupBy (cmpEq lcmCmp) (sortLBy lcmCmp itcs)
itcsToC = (.m) . head
itcss' = TS.measurePure "1.1.3 M_ik" $ seqElts $ filter (noDivs . itcsToC) itcss
where -- criterion M_ik; 3 seqElts calls for TS.measurePure
Expand Down Expand Up @@ -713,7 +713,7 @@ groebnerBasis gbpA@(GBPolyOps { .. }) gbTrace gbi0 newGens = TS.scope $ do
pure True
doSP = maybe (pure False) newG =<< setPop ijcsRef
mapM_ (\g -> addGenHN =<< reduce_n (EPolyHDeg g (homogDeg0 g)))
(sortBy (evCmp `on` leadEvNz) (filter (not . pIsZero) newGens))
(sortLBy (evCmp `on` leadEvNz) (filter (not . pIsZero) newGens))
numSleepingVar <- newTVarIO (0 :: Int)
let traceTime = do
cpuTime2 <- getCPUTime
Expand Down
25 changes: 23 additions & 2 deletions src/Math/Algebra/General/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ module Math.Algebra.General.Algebra (
Cmp,
cmpEq,
maxBy, minBy,
isSortedBy,

-- * Monoids and Groups
-- $monoids
Expand Down Expand Up @@ -96,7 +97,7 @@ module Math.Algebra.General.Algebra (
-- * Basic numeric rings
numAG, numRing,
-- ** Integer
zzAG, zzDiv, zzRing,
zzAG, zzDiv, zzRing, intRing,
-- ** Double
dblAG, dblRing,

Expand All @@ -123,7 +124,7 @@ module Math.Algebra.General.Algebra (
#if ! MIN_VERSION_base(4, 18, 0)
liftA2,
#endif
assert
assert, sortLBy
) where

import GHC.Records
Expand All @@ -132,6 +133,7 @@ import GHC.Records
import Control.Applicative (liftA2) -- unnecessary in base 4.18+, since in Prelude
#endif
import Control.Exception (assert)
import Control.Parallel.Cooperative (sortLBy)
import Data.Bifunctor (bimap, second)
import Data.Char (isDigit)
#if ! MIN_VERSION_base(4, 20, 0)
Expand Down Expand Up @@ -243,6 +245,16 @@ minBy :: Cmp a -> Op2 a
-- ^ > minBy cmp x y = if cmp x y /= GT then x else y
minBy cmp x y = if cmp x y /= GT then x else y

-- | @'isSortedBy' lte xs@ is 'True' iff @lte@ returns true for every pair of
-- adjacent elements of @xs@. Lists with fewer than two elements are always sorted.
isSortedBy :: (a -> a -> Bool) -> [a] -> Bool
-- from Data.List.Ordered in data-ordlist
isSortedBy lte = go
  where
    -- check the two leading elements, then continue from the second one
    go (x : rest@(y : _)) = lte x y && go rest
    go _                  = True


-- * Monoids and Groups

Expand Down Expand Up @@ -642,6 +654,15 @@ zzRing :: Ring Integer
-- ^ the ring of integers ℤ
zzRing = numRing integralDomainFlags zzDiv

intRing :: Ring Int
{- ^ Arithmetic mod @2^n@, where an @Int@ has @n@ bits. Division is simply @Int@ 'quotRem',
    except that dividing by 0 yields quotient 0 with the dividend as the remainder. This is
    mostly used for testing and benchmarks. -}
intRing = numRing flags (\_ -> divide)
  where
    flags = RingFlags { commutative = True, noZeroDivisors = False, nzInverses = False }
    -- a zero divisor is handled explicitly rather than passed to 'quotRem'
    divide y m = if m == 0 then (0, y) else quotRem y m

-- ** Double

dblAG :: AbelianGroup Double
Expand Down
3 changes: 1 addition & 2 deletions src/Math/Algebra/General/SparseSum.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import Math.Algebra.Category.Category

import Data.Bifunctor (Bifunctor(first, second, bimap))
import Data.Foldable (toList)
import Data.List (sortBy)
import Data.Maybe (fromJust)
import GHC.Stack (HasCallStack)
import StrictList2 (pattern (:!))
Expand Down Expand Up @@ -157,7 +156,7 @@ ssAGUniv (AbelianGroup cFlags eq plus _zero isZero neg) dCmp =
ssFoldSort :: AbelianGroup c -> Cmp d -> [SSTerm c d] -> SparseSum c d
-- ^ sorts and combines the terms; input terms may have coefs which are 0
ssFoldSort (AbelianGroup _ _ cPlus _ cIsZero _) dCmp cds0 =
go ssZero (sortBy (dCmp `on` (.d)) cds0)
go ssZero (sortLBy (dCmp `on` (.d)) cds0)
where
go done [] = done
go done (cd : t) = go1 done cd.d cd.c t
Expand Down
Loading

0 comments on commit fe8063c

Please sign in to comment.