Skip to content

Commit 17127e0

Browse files
authored
Add 'fromListWithKey' to HashMap (#246)
* Add 'fromListWithKey' to HashMap Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * Define `unsafeInsertWith` using `unsafeInsertWithKey` See #246 (comment) for performance numbers. * Improve documentation of fromListWithKey Thanks to @emilypi for the suggestions * Add a test for fromListWithKey It models the test for fromListWith but makes sure that values are combined in a way that depends on the key. * Clarify the documentation for fromListWithKey * Improve properties for fromListWith and fromListWithKey The old properties used associative operators to combine values when there were duplicate keys. With this diff we're using a non-commutative and non-associative operator which can catch more bugs. * Make sure to return an unboxed tuple in unsafeFromListWithKey This got lost in a rebase * Use the free magma to ensure that we preserve the order of operations * Update fromListWithKey documentation to use non-commutative, non-associative operators Since the combining function is applied in a way that can be counter-intuitive it's more pedagical to to use operators which better illustrate this behaviour.
1 parent 0952034 commit 17127e0

File tree

5 files changed

+102
-11
lines changed

5 files changed

+102
-11
lines changed

Data/HashMap/Base.hs

+39-4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ module Data.HashMap.Base
8282
, toList
8383
, fromList
8484
, fromListWith
85+
, fromListWithKey
8586

8687
-- Internals used by the strict version
8788
, Hash
@@ -1031,14 +1032,20 @@ insertModifyingArr x f k0 ary0 = go k0 ary0 0 (A.length ary0)
10311032
unsafeInsertWith :: forall k v. (Eq k, Hashable k)
10321033
=> (v -> v -> v) -> k -> v -> HashMap k v
10331034
-> HashMap k v
1034-
unsafeInsertWith f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
1035+
unsafeInsertWith f k0 v0 m0 = unsafeInsertWithKey (const f) k0 v0 m0
1036+
{-# INLINABLE unsafeInsertWith #-}
1037+
1038+
unsafeInsertWithKey :: forall k v. (Eq k, Hashable k)
1039+
=> (k -> v -> v -> v) -> k -> v -> HashMap k v
1040+
-> HashMap k v
1041+
unsafeInsertWithKey f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
10351042
where
10361043
h0 = hash k0
10371044
go :: Hash -> k -> v -> Shift -> HashMap k v -> ST s (HashMap k v)
10381045
go !h !k x !_ Empty = return $! Leaf h (L k x)
10391046
go h k x s t@(Leaf hy l@(L ky y))
10401047
| hy == h = if ky == k
1041-
then return $! Leaf h (L k (f x y))
1048+
then return $! Leaf h (L k (f k x y))
10421049
else return $! collision h l (L k x)
10431050
| otherwise = two s h k x hy t
10441051
go h k x s t@(BitmapIndexed b ary)
@@ -1059,9 +1066,9 @@ unsafeInsertWith f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
10591066
return t
10601067
where i = index h s
10611068
go h k x s t@(Collision hy v)
1062-
| h == hy = return $! Collision h (updateOrSnocWith (\a b -> (# f a b #)) k x v)
1069+
| h == hy = return $! Collision h (updateOrSnocWithKey (\key a b -> (# f key a b #) ) k x v)
10631070
| otherwise = go h k x s $ BitmapIndexed (mask hy s) (A.singleton t)
1064-
{-# INLINABLE unsafeInsertWith #-}
1071+
{-# INLINABLE unsafeInsertWithKey #-}
10651072

10661073
-- | /O(log n)/ Remove the mapping for the specified key from this map
10671074
-- if present.
@@ -1916,6 +1923,34 @@ fromListWith :: (Eq k, Hashable k) => (v -> v -> v) -> [(k, v)] -> HashMap k v
19161923
fromListWith f = L.foldl' (\ m (k, v) -> unsafeInsertWith f k v m) empty
19171924
{-# INLINE fromListWith #-}
19181925

1926+
-- | /O(n*log n)/ Construct a map from a list of elements. Uses
1927+
-- the provided function to merge duplicate entries.
1928+
--
1929+
-- === Examples
1930+
--
1931+
-- Given a list of key-value pairs where the keys are of different flavours, e.g:
1932+
--
1933+
-- > data Key = Div | Sub
1934+
--
1935+
-- and the values need to be combined differently when there are duplicates,
1936+
-- depending on the key:
1937+
--
1938+
-- > combine Div = div
1939+
-- > combine Sub = (-)
1940+
--
1941+
-- then @fromListWithKey@ can be used as follows:
1942+
--
1943+
-- > fromListWithKey combine [(Div, 2), (Div, 6), (Sub, 2), (Sub, 3)]
1944+
-- > = fromList [(Div, 3), (Sub, 1)]
1945+
--
1946+
-- More generally, duplicate entries are accumulated as follows;
1947+
--
1948+
-- > fromListWith f [(k, a), (k, b), (k, c), (k, d)]
1949+
-- > = fromList [(k, f k d (f k c (f k b a)))]
1950+
fromListWithKey :: (Eq k, Hashable k) => (k -> v -> v -> v) -> [(k, v)] -> HashMap k v
1951+
fromListWithKey f = L.foldl' (\ m (k, v) -> unsafeInsertWithKey f k v m) empty
1952+
{-# INLINE fromListWithKey #-}
1953+
19191954
------------------------------------------------------------------------
19201955
-- Array operations
19211956

Data/HashMap/Lazy.hs

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ module Data.HashMap.Lazy
9494
, toList
9595
, fromList
9696
, fromListWith
97+
, fromListWithKey
9798

9899
-- ** HashSets
99100
, HS.keysSet

Data/HashMap/Strict.hs

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ module Data.HashMap.Strict
9393
, toList
9494
, fromList
9595
, fromListWith
96+
, fromListWithKey
9697

9798
-- ** HashSets
9899
, HS.keysSet

Data/HashMap/Strict/Base.hs

+40-5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ module Data.HashMap.Strict.Base
9595
, toList
9696
, fromList
9797
, fromListWith
98+
, fromListWithKey
9899
) where
99100

100101
import Data.Bits ((.&.), (.|.))
@@ -109,7 +110,8 @@ import Prelude hiding (map, lookup)
109110
import qualified Data.HashMap.Array as A
110111
import qualified Data.HashMap.Base as HM
111112
import Data.HashMap.Base hiding (
112-
alter, alterF, adjust, fromList, fromListWith, insert, insertWith,
113+
alter, alterF, adjust, fromList, fromListWith, fromListWithKey,
114+
insert, insertWith,
113115
differenceWith, intersectionWith, intersectionWithKey, map, mapWithKey,
114116
mapMaybe, mapMaybeWithKey, singleton, update, unionWith, unionWithKey,
115117
traverseWithKey)
@@ -189,13 +191,18 @@ insertWith f k0 v0 m0 = go h0 k0 v0 0 m0
189191
-- | In-place update version of insertWith
190192
unsafeInsertWith :: (Eq k, Hashable k) => (v -> v -> v) -> k -> v -> HashMap k v
191193
-> HashMap k v
192-
unsafeInsertWith f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
194+
unsafeInsertWith f k0 v0 m0 = unsafeInsertWithKey (const f) k0 v0 m0
195+
{-# INLINABLE unsafeInsertWith #-}
196+
197+
unsafeInsertWithKey :: (Eq k, Hashable k) => (k -> v -> v -> v) -> k -> v -> HashMap k v
198+
-> HashMap k v
199+
unsafeInsertWithKey f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
193200
where
194201
h0 = hash k0
195202
go !h !k x !_ Empty = return $! leaf h k x
196203
go h k x s t@(Leaf hy l@(L ky y))
197204
| hy == h = if ky == k
198-
then return $! leaf h k (f x y)
205+
then return $! leaf h k (f k x y)
199206
else do
200207
let l' = x `seq` (L k x)
201208
return $! collision h l l'
@@ -218,9 +225,9 @@ unsafeInsertWith f k0 v0 m0 = runST (go h0 k0 v0 0 m0)
218225
return t
219226
where i = index h s
220227
go h k x s t@(Collision hy v)
221-
| h == hy = return $! Collision h (updateOrSnocWith f k x v)
228+
| h == hy = return $! Collision h (updateOrSnocWithKey f k x v)
222229
| otherwise = go h k x s $ BitmapIndexed (mask hy s) (A.singleton t)
223-
{-# INLINABLE unsafeInsertWith #-}
230+
{-# INLINABLE unsafeInsertWithKey #-}
224231

225232
-- | /O(log n)/ Adjust the value tied to a given key in this map only
226233
-- if it is present. Otherwise, leave the map alone.
@@ -639,6 +646,34 @@ fromListWith :: (Eq k, Hashable k) => (v -> v -> v) -> [(k, v)] -> HashMap k v
639646
fromListWith f = L.foldl' (\ m (k, v) -> unsafeInsertWith f k v m) empty
640647
{-# INLINE fromListWith #-}
641648

649+
-- | /O(n*log n)/ Construct a map from a list of elements. Uses
650+
-- the provided function to merge duplicate entries.
651+
--
652+
-- === Examples
653+
--
654+
-- Given a list of key-value pairs where the keys are of different flavours, e.g:
655+
--
656+
-- > data Key = Div | Sub
657+
--
658+
-- and the values need to be combined differently when there are duplicates,
659+
-- depending on the key:
660+
--
661+
-- > combine Div = div
662+
-- > combine Sub = (-)
663+
--
664+
-- then @fromListWithKey@ can be used as follows:
665+
--
666+
-- > fromListWithKey combine [(Div, 2), (Div, 6), (Sub, 2), (Sub, 3)]
667+
-- > = fromList [(Div, 3), (Sub, 1)]
668+
--
669+
-- More generally, duplicate entries are accumulated as follows;
670+
--
671+
-- > fromListWith f [(k, a), (k, b), (k, c), (k, d)]
672+
-- > = fromList [(k, f k d (f k c (f k b a)))]
673+
fromListWithKey :: (Eq k, Hashable k) => (k -> v -> v -> v) -> [(k, v)] -> HashMap k v
674+
fromListWithKey f = L.foldl' (\ m (k, v) -> unsafeInsertWithKey f k v m) empty
675+
{-# INLINE fromListWithKey #-}
676+
642677
------------------------------------------------------------------------
643678
-- Array operations
644679

tests/HashMapProperties.hs

+21-2
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,31 @@ pFilterWithKey = M.filterWithKey p `eq_` HM.filterWithKey p
363363
------------------------------------------------------------------------
364364
-- ** Conversions
365365

366+
-- The free magma is used to test that operations are applied in the
367+
-- same order.
368+
data Magma a
369+
= Leaf a
370+
| Op (Magma a) (Magma a)
371+
deriving (Show, Eq, Ord)
372+
373+
instance Hashable a => Hashable (Magma a) where
374+
hashWithSalt s (Leaf a) = hashWithSalt s (hashWithSalt (1::Int) a)
375+
hashWithSalt s (Op m n) = hashWithSalt s (hashWithSalt (hashWithSalt (2::Int) m) n)
376+
366377
-- 'eq_' already calls fromList.
367378
pFromList :: [(Key, Int)] -> Bool
368379
pFromList = id `eq_` id
369380

370381
pFromListWith :: [(Key, Int)] -> Bool
371-
pFromListWith kvs = (M.toAscList $ M.fromListWith (+) kvs) ==
372-
(toAscList $ HM.fromListWith (+) kvs)
382+
pFromListWith kvs = (M.toAscList $ M.fromListWith Op kvsM) ==
383+
(toAscList $ HM.fromListWith Op kvsM)
384+
where kvsM = fmap (fmap Leaf) kvs
385+
386+
pFromListWithKey :: [(Key, Int)] -> Bool
387+
pFromListWithKey kvs = (M.toAscList $ M.fromListWithKey combine kvsM) ==
388+
(toAscList $ HM.fromListWithKey combine kvsM)
389+
where kvsM = fmap (\(K k,v) -> (Leaf k, Leaf v)) kvs
390+
combine k v1 v2 = Op k (Op v1 v2)
373391

374392
pToList :: [(Key, Int)] -> Bool
375393
pToList = M.toAscList `eq` toAscList
@@ -467,6 +485,7 @@ tests =
467485
, testProperty "keys" pKeys
468486
, testProperty "fromList" pFromList
469487
, testProperty "fromListWith" pFromListWith
488+
, testProperty "fromListWithKey" pFromListWithKey
470489
, testProperty "toList" pToList
471490
]
472491
]

0 commit comments

Comments
 (0)