diff --git a/containers-tests/tests/intset-properties.hs b/containers-tests/tests/intset-properties.hs index c93d6e8af..201a27104 100644 --- a/containers-tests/tests/intset-properties.hs +++ b/containers-tests/tests/intset-properties.hs @@ -7,6 +7,11 @@ import Data.List (nub,sort) import qualified Data.List as List import Data.Maybe (listToMaybe) import Data.Monoid (mempty) +#if MIN_VERSION_base(4,18,0) +import Data.List.NonEmpty (NonEmpty(..)) +import qualified Data.List.NonEmpty as NE +import qualified Data.Foldable1 as Foldable1 +#endif import qualified Data.Set as Set import IntSetValidity (valid) import Prelude hiding (lookup, null, map, filter, foldr, foldl, foldl') @@ -82,6 +87,10 @@ main = defaultMain $ testGroup "intset-properties" , testProperty "prop_bitcount" prop_bitcount , testProperty "prop_alterF_list" prop_alterF_list , testProperty "prop_alterF_const" prop_alterF_const +#if MIN_VERSION_base(4,18,0) + , testProperty "intersections" prop_intersections + , testProperty "intersections_lazy" prop_intersections_lazy +#endif ] ---------------------------------------------------------------- @@ -500,3 +509,18 @@ prop_alterF_const prop_alterF_const f k s = getConst (alterF (Const . applyFun f) k s ) === getConst (Set.alterF (Const . applyFun f) k (toSet s)) + +#if MIN_VERSION_base(4,18,0) +prop_intersections :: (IntSet, [IntSet]) -> Property +prop_intersections (s, ss) = + intersections ss' === Foldable1.foldl1' intersection ss' + where + ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance + +prop_intersections_lazy :: [IntSet] -> Property +prop_intersections_lazy ss = intersections ss' === empty + where + ss' = NE.fromList $ ss ++ [empty] ++ undefined + -- ^ result will certainly be empty at this point, + -- so the rest of the list should not be demanded. +#endif diff --git a/containers-tests/tests/set-properties.hs b/containers-tests/tests/set-properties.hs index 82d30f0c6..90a3e2d5d 100644 --- a/containers-tests/tests/set-properties.hs +++ b/containers-tests/tests/set-properties.hs @@ -16,6 +16,11 @@ import Control.Monad (liftM, liftM3) import Data.Functor.Identity import Data.Foldable (all) import Control.Applicative (liftA2) +#if MIN_VERSION_base(4,18,0) +import Data.List.NonEmpty (NonEmpty(..)) +import qualified Data.List.NonEmpty as NE +import qualified Data.Foldable1 as Foldable1 +#endif #if __GLASGOW_HASKELL__ >= 806 import Utils.NoThunks (whnfHasNoThunks) @@ -112,6 +117,10 @@ main = defaultMain $ testGroup "set-properties" #endif , testProperty "eq" prop_eq , testProperty "compare" prop_compare +#if MIN_VERSION_base(4,18,0) + , testProperty "intersections" prop_intersections + , testProperty "intersections_lazy" prop_intersections_lazy +#endif ] -- A type with a peculiar Eq instance designed to make sure keys @@ -738,3 +747,18 @@ prop_eq s1 s2 = (s1 == s2) === (toList s1 == toList s2) prop_compare :: Set Int -> Set Int -> Property prop_compare s1 s2 = compare s1 s2 === compare (toList s1) (toList s2) + +#if MIN_VERSION_base(4,18,0) +prop_intersections :: (Set Int, [Set Int]) -> Property +prop_intersections (s, ss) = + intersections ss' === Foldable1.foldl1' intersection ss' + where + ss' = s :| ss -- Work around missing Arbitrary NonEmpty instance + +prop_intersections_lazy :: [Set Int] -> Property +prop_intersections_lazy ss = intersections ss' === empty + where + ss' = NE.fromList $ ss ++ [empty] ++ undefined + -- ^ result will certainly be empty at this point, + -- so the rest of the list should not be demanded. +#endif diff --git a/containers/changelog.md b/containers/changelog.md index c0e00aa83..d650886fc 100644 --- a/containers/changelog.md +++ b/containers/changelog.md @@ -39,6 +39,9 @@ * Add `lookupMin` and `lookupMax` for `Data.IntSet`. (Soumik Sarkar) +* Add `Intersection` and `intersections` for `Data.Set` and `Data.IntSet`. + (Reed Mullanix, Soumik Sarkar) + ## Unreleased with `@since` annotation for 0.7.1: ### Additions diff --git a/containers/src/Data/IntSet.hs b/containers/src/Data/IntSet.hs index 9d497e91e..83d555f2d 100644 --- a/containers/src/Data/IntSet.hs +++ b/containers/src/Data/IntSet.hs @@ -106,7 +106,11 @@ module Data.IntSet ( , difference , (\\) , intersection +#if MIN_VERSION_base(4,18,0) + , intersections +#endif , symmetricDifference + , Intersection(..) -- * Filter , IS.filter diff --git a/containers/src/Data/IntSet/Internal.hs b/containers/src/Data/IntSet/Internal.hs index 0af00e4aa..f8afaed64 100644 --- a/containers/src/Data/IntSet/Internal.hs +++ b/containers/src/Data/IntSet/Internal.hs @@ -125,7 +125,11 @@ module Data.IntSet.Internal ( , unions , difference , intersection +#if MIN_VERSION_base(4,18,0) + , intersections +#endif , symmetricDifference + , Intersection(..) -- * Filter , filter @@ -192,11 +196,15 @@ import Control.DeepSeq (NFData(rnf)) import Data.Bits import qualified Data.List as List import Data.Maybe (fromMaybe) -import Data.Semigroup (Semigroup(stimes)) +import Data.Semigroup + (Semigroup(stimes), stimesIdempotent, stimesIdempotentMonoid) #if !(MIN_VERSION_base(4,11,0)) import Data.Semigroup (Semigroup((<>))) #endif -import Data.Semigroup (stimesIdempotentMonoid) +#if MIN_VERSION_base(4,18,0) +import qualified Data.Foldable1 as Foldable1 +import Data.List.NonEmpty (NonEmpty(..)) +#endif import Utils.Containers.Internal.Prelude hiding (filter, foldr, foldl, foldl', null, map) import Prelude () @@ -659,6 +667,40 @@ intersection (Tip kx1 bm1) t2 = intersectBM t2 intersection Nil _ = Nil +#if MIN_VERSION_base(4,18,0) +-- | The intersection of a series of sets. Intersections are performed +-- left-to-right. +-- +-- @since FIXME +intersections :: Foldable1.Foldable1 f => f IntSet -> IntSet +intersections ss = case Foldable1.toNonEmpty ss of + s0 :| ss' + | null s0 -> empty + | otherwise -> List.foldr go id ss' s0 + where + go s r acc + | null acc' = empty + | otherwise = r acc' + where + acc' = intersection acc s +{-# INLINABLE intersections #-} +#endif + +-- | @IntSet@s form a 'Semigroup' under 'intersection'. +-- +-- A @Monoid@ instance is not defined because it would be impractical to +-- construct @mempty@, the @IntSet@ containing all @Int@s. +-- +-- @since FIXME +newtype Intersection = Intersection { getIntersection :: IntSet } + deriving (Show, Eq, Ord) + +instance Semigroup Intersection where + Intersection s1 <> Intersection s2 = Intersection (intersection s1 s2) + + stimes = stimesIdempotent + {-# INLINABLE stimes #-} + {-------------------------------------------------------------------- Symmetric difference --------------------------------------------------------------------} diff --git a/containers/src/Data/Set.hs b/containers/src/Data/Set.hs index ddfa9f02b..030b4e009 100644 --- a/containers/src/Data/Set.hs +++ b/containers/src/Data/Set.hs @@ -111,9 +111,13 @@ module Data.Set ( , difference , (\\) , intersection +#if MIN_VERSION_base(4,18,0) + , intersections +#endif , symmetricDifference , cartesianProduct , disjointUnion + , Intersection(..) -- * Filter , S.filter diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index 805f5c361..9e310dda7 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -155,7 +155,9 @@ module Data.Set.Internal ( , unions , difference , intersection +#if MIN_VERSION_base(4,18,0) , intersections +#endif , symmetricDifference , cartesianProduct , disjointUnion @@ -240,7 +242,6 @@ import Control.Applicative (Const(..)) import qualified Data.List as List import Data.Bits (shiftL, shiftR) import Data.Semigroup (Semigroup(stimes)) -import Data.List.NonEmpty (NonEmpty(..)) #if !(MIN_VERSION_base(4,11,0)) import Data.Semigroup (Semigroup((<>))) #endif @@ -249,6 +250,10 @@ import Data.Functor.Classes import Data.Functor.Identity (Identity) import qualified Data.Foldable as Foldable import Control.DeepSeq (NFData(rnf)) +#if MIN_VERSION_base(4,18,0) +import qualified Data.Foldable1 as Foldable1 +import Data.List.NonEmpty (NonEmpty(..)) +#endif import Utils.Containers.Internal.StrictPair import Utils.Containers.Internal.PtrEquality @@ -894,21 +899,37 @@ intersection t1@(Bin _ x l1 r1) t2 {-# INLINABLE intersection #-} #endif --- | The intersection of a series of sets. Intersections are performed left-to-right. -intersections :: Ord a => NonEmpty (Set a) -> Set a -intersections (s0 :| ss) = List.foldr go id ss s0 - where - go s r acc - | null acc = empty - | otherwise = r (intersection acc s) +#if MIN_VERSION_base(4,18,0) +-- | The intersection of a series of sets. Intersections are performed +-- left-to-right. +-- +-- @since FIXME +intersections :: (Foldable1.Foldable1 f, Ord a) => f (Set a) -> Set a +intersections ss = case Foldable1.toNonEmpty ss of + s0 :| ss' + | null s0 -> empty + | otherwise -> List.foldr go id ss' s0 + where + go s r acc + | null acc' = empty + | otherwise = r acc' + where + acc' = intersection acc s +{-# INLINABLE intersections #-} +#endif --- | Sets form a 'Semigroup' under 'intersection'. +-- | @Set@s form a 'Semigroup' under 'intersection'. +-- +-- @since FIXME newtype Intersection a = Intersection { getIntersection :: Set a } deriving (Show, Eq, Ord) instance (Ord a) => Semigroup (Intersection a) where (Intersection a) <> (Intersection b) = Intersection $ intersection a b + {-# INLINABLE (<>) #-} + stimes = stimesIdempotent + {-# INLINABLE stimes #-} {-------------------------------------------------------------------- Symmetric difference