Skip to content

Commit

Permalink
Update and export set intersection utilities (#1040)
Browse files Browse the repository at this point in the history
* Update Set.intersections to be lazier
* Mark definitions INLINABLE for specialization
* Add matching Intersection and intersections for IntSet
* Add property tests
  • Loading branch information
meooow25 authored Sep 22, 2024
1 parent 2a109ad commit dc98ae7
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 11 deletions.
24 changes: 24 additions & 0 deletions containers-tests/tests/intset-properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import Data.List (nub,sort)
import qualified Data.List as List
import Data.Maybe (listToMaybe)
import Data.Monoid (mempty)
#if MIN_VERSION_base(4,18,0)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Foldable1 as Foldable1
#endif
import qualified Data.Set as Set
import IntSetValidity (valid)
import Prelude hiding (lookup, null, map, filter, foldr, foldl, foldl')
Expand Down Expand Up @@ -82,6 +87,10 @@ main = defaultMain $ testGroup "intset-properties"
, testProperty "prop_bitcount" prop_bitcount
, testProperty "prop_alterF_list" prop_alterF_list
, testProperty "prop_alterF_const" prop_alterF_const
#if MIN_VERSION_base(4,18,0)
, testProperty "intersections" prop_intersections
, testProperty "intersections_lazy" prop_intersections_lazy
#endif
]

----------------------------------------------------------------
Expand Down Expand Up @@ -500,3 +509,18 @@ prop_alterF_const
prop_alterF_const f k s =
getConst (alterF (Const . applyFun f) k s )
=== getConst (Set.alterF (Const . applyFun f) k (toSet s))

#if MIN_VERSION_base(4,18,0)
prop_intersections :: (IntSet, [IntSet]) -> Property
prop_intersections (s, ss) =
intersections ss' === Foldable1.foldl1' intersection ss'
where
ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance

prop_intersections_lazy :: [IntSet] -> Property
prop_intersections_lazy ss = intersections ss' === empty
where
ss' = NE.fromList $ ss ++ [empty] ++ undefined
-- ^ result will certainly be empty at this point,
-- so the rest of the list should not be demanded.
#endif
24 changes: 24 additions & 0 deletions containers-tests/tests/set-properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import Control.Monad (liftM, liftM3)
import Data.Functor.Identity
import Data.Foldable (all)
import Control.Applicative (liftA2)
#if MIN_VERSION_base(4,18,0)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Foldable1 as Foldable1
#endif

#if __GLASGOW_HASKELL__ >= 806
import Utils.NoThunks (whnfHasNoThunks)
Expand Down Expand Up @@ -112,6 +117,10 @@ main = defaultMain $ testGroup "set-properties"
#endif
, testProperty "eq" prop_eq
, testProperty "compare" prop_compare
#if MIN_VERSION_base(4,18,0)
, testProperty "intersections" prop_intersections
, testProperty "intersections_lazy" prop_intersections_lazy
#endif
]

-- A type with a peculiar Eq instance designed to make sure keys
Expand Down Expand Up @@ -738,3 +747,18 @@ prop_eq s1 s2 = (s1 == s2) === (toList s1 == toList s2)

prop_compare :: Set Int -> Set Int -> Property
prop_compare s1 s2 = compare s1 s2 === compare (toList s1) (toList s2)

#if MIN_VERSION_base(4,18,0)
prop_intersections :: (Set Int, [Set Int]) -> Property
prop_intersections (s, ss) =
intersections ss' === Foldable1.foldl1' intersection ss'
where
ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance

prop_intersections_lazy :: [Set Int] -> Property
prop_intersections_lazy ss = intersections ss' === empty
where
ss' = NE.fromList $ ss ++ [empty] ++ undefined
-- ^ result will certainly be empty at this point,
-- so the rest of the list should not be demanded.
#endif
3 changes: 3 additions & 0 deletions containers/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@

* Add `lookupMin` and `lookupMax` for `Data.IntSet`. (Soumik Sarkar)

* Add `Intersection` and `intersections` for `Data.Set` and `Data.IntSet`.
(Reed Mullanix, Soumik Sarkar)

## Unreleased with `@since` annotation for 0.7.1:

### Additions
Expand Down
4 changes: 4 additions & 0 deletions containers/src/Data/IntSet.hs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ module Data.IntSet (
, difference
, (\\)
, intersection
#if MIN_VERSION_base(4,18,0)
, intersections
#endif
, symmetricDifference
, Intersection(..)

-- * Filter
, IS.filter
Expand Down
46 changes: 44 additions & 2 deletions containers/src/Data/IntSet/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,11 @@ module Data.IntSet.Internal (
, unions
, difference
, intersection
#if MIN_VERSION_base(4,18,0)
, intersections
#endif
, symmetricDifference
, Intersection(..)

-- * Filter
, filter
Expand Down Expand Up @@ -192,11 +196,15 @@ import Control.DeepSeq (NFData(rnf))
import Data.Bits
import qualified Data.List as List
import Data.Maybe (fromMaybe)
import Data.Semigroup (Semigroup(stimes))
import Data.Semigroup
(Semigroup(stimes), stimesIdempotent, stimesIdempotentMonoid)
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup (Semigroup((<>)))
#endif
import Data.Semigroup (stimesIdempotentMonoid)
#if MIN_VERSION_base(4,18,0)
import qualified Data.Foldable1 as Foldable1
import Data.List.NonEmpty (NonEmpty(..))
#endif
import Utils.Containers.Internal.Prelude hiding
(filter, foldr, foldl, foldl', null, map)
import Prelude ()
Expand Down Expand Up @@ -659,6 +667,40 @@ intersection (Tip kx1 bm1) t2 = intersectBM t2

intersection Nil _ = Nil

#if MIN_VERSION_base(4,18,0)
-- | The intersection of a series of sets. Intersections are performed
-- left-to-right.
--
-- @since FIXME
intersections :: Foldable1.Foldable1 f => f IntSet -> IntSet
intersections ss = case Foldable1.toNonEmpty ss of
s0 :| ss'
| null s0 -> empty
| otherwise -> List.foldr go id ss' s0
where
go s r acc
| null acc' = empty
| otherwise = r acc'
where
acc' = intersection acc s
{-# INLINABLE intersections #-}
#endif

-- | @IntSet@s form a 'Semigroup' under 'intersection'.
--
-- A @Monoid@ instance is not defined because it would be impractical to
-- construct @mempty@, the @IntSet@ containing all @Int@s.
--
-- @since FIXME
newtype Intersection = Intersection { getIntersection :: IntSet }
deriving (Show, Eq, Ord)

instance Semigroup Intersection where
Intersection s1 <> Intersection s2 = Intersection (intersection s1 s2)

stimes = stimesIdempotent
{-# INLINABLE stimes #-}

{--------------------------------------------------------------------
Symmetric difference
--------------------------------------------------------------------}
Expand Down
4 changes: 4 additions & 0 deletions containers/src/Data/Set.hs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,13 @@ module Data.Set (
, difference
, (\\)
, intersection
#if MIN_VERSION_base(4,18,0)
, intersections
#endif
, symmetricDifference
, cartesianProduct
, disjointUnion
, Intersection(..)

-- * Filter
, S.filter
Expand Down
39 changes: 30 additions & 9 deletions containers/src/Data/Set/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ module Data.Set.Internal (
, unions
, difference
, intersection
#if MIN_VERSION_base(4,18,0)
, intersections
#endif
, symmetricDifference
, cartesianProduct
, disjointUnion
Expand Down Expand Up @@ -240,7 +242,6 @@ import Control.Applicative (Const(..))
import qualified Data.List as List
import Data.Bits (shiftL, shiftR)
import Data.Semigroup (Semigroup(stimes))
import Data.List.NonEmpty (NonEmpty(..))
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup (Semigroup((<>)))
#endif
Expand All @@ -249,6 +250,10 @@ import Data.Functor.Classes
import Data.Functor.Identity (Identity)
import qualified Data.Foldable as Foldable
import Control.DeepSeq (NFData(rnf))
#if MIN_VERSION_base(4,18,0)
import qualified Data.Foldable1 as Foldable1
import Data.List.NonEmpty (NonEmpty(..))
#endif

import Utils.Containers.Internal.StrictPair
import Utils.Containers.Internal.PtrEquality
Expand Down Expand Up @@ -894,21 +899,37 @@ intersection t1@(Bin _ x l1 r1) t2
{-# INLINABLE intersection #-}
#endif

-- | The intersection of a series of sets. Intersections are performed left-to-right.
intersections :: Ord a => NonEmpty (Set a) -> Set a
intersections (s0 :| ss) = List.foldr go id ss s0
where
go s r acc
| null acc = empty
| otherwise = r (intersection acc s)
#if MIN_VERSION_base(4,18,0)
-- | The intersection of a series of sets. Intersections are performed
-- left-to-right.
--
-- @since FIXME
intersections :: (Foldable1.Foldable1 f, Ord a) => f (Set a) -> Set a
intersections ss = case Foldable1.toNonEmpty ss of
s0 :| ss'
| null s0 -> empty
| otherwise -> List.foldr go id ss' s0
where
go s r acc
| null acc' = empty
| otherwise = r acc'
where
acc' = intersection acc s
{-# INLINABLE intersections #-}
#endif

-- | Sets form a 'Semigroup' under 'intersection'.
-- | @Set@s form a 'Semigroup' under 'intersection'.
--
-- @since FIXME
newtype Intersection a = Intersection { getIntersection :: Set a }
deriving (Show, Eq, Ord)

instance (Ord a) => Semigroup (Intersection a) where
(Intersection a) <> (Intersection b) = Intersection $ intersection a b
{-# INLINABLE (<>) #-}

stimes = stimesIdempotent
{-# INLINABLE stimes #-}

{--------------------------------------------------------------------
Symmetric difference
Expand Down

0 comments on commit dc98ae7

Please sign in to comment.