From 11c21a50b065c8bf750b1216f51b59739792e345 Mon Sep 17 00:00:00 2001 From: Alec Theriault Date: Fri, 10 Aug 2018 17:38:28 -0700 Subject: [PATCH] Check 'fromDistinctAscList' invariants Check that the keys of Set/Map/InstSet/IntMap really are ascending when deserializing. This requires adding an extra 'Ord' constraint to the 'Set' and 'Map' instances. --- src/Data/Binary/Class.hs | 48 +++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/src/Data/Binary/Class.hs b/src/Data/Binary/Class.hs index 2eed93e0..699b1eee 100644 --- a/src/Data/Binary/Class.hs +++ b/src/Data/Binary/Class.hs @@ -638,22 +638,54 @@ instance Binary BS.ShortByteString where ------------------------------------------------------------------------ -- Maps and Sets -instance (Binary a) => Binary (Set.Set a) where +instance (Ord a, Binary a) => Binary (Set.Set a) where put s = put (Set.size s) <> mapM_ put (Set.toAscList s) - get = liftM Set.fromDistinctAscList get - -instance (Binary k, Binary e) => Binary (Map.Map k e) where + get = do ascList <- get + case ascList of + [] -> pure Set.empty + x:xs -> do guardAsc x xs + pure (Set.fromDistinctAscList ascList) + where guardAsc _ [] = pure () + guardAsc x (y:xs) + | x < y = guardAsc y xs + | otherwise = fail "Set values are not ascending!" + +instance (Ord k, Binary k, Binary e) => Binary (Map.Map k e) where put m = put (Map.size m) <> mapM_ put (Map.toAscList m) - get = liftM Map.fromDistinctAscList get + get = do ascList <- get + case ascList of + [] -> pure Map.empty + (k,_):kvs -> do guardAsc k kvs + pure (Map.fromDistinctAscList ascList) + where guardAsc _ [] = pure () + guardAsc k ((j,_):kvs) + | k < j = guardAsc j kvs + | otherwise = fail "Map keys are not ascending!" instance Binary IntSet.IntSet where put s = put (IntSet.size s) <> mapM_ put (IntSet.toAscList s) - get = liftM IntSet.fromDistinctAscList get + get = do ascList <- get + case ascList of + [] -> pure IntSet.empty + i:is -> do guardAsc i is + pure (IntSet.fromDistinctAscList ascList) + where guardAsc _ [] = pure () + guardAsc i (j:is) + | i < j = guardAsc j is + | otherwise = fail "IntSet values are not ascending!" instance (Binary e) => Binary (IntMap.IntMap e) where put m = put (IntMap.size m) <> mapM_ put (IntMap.toAscList m) - get = liftM IntMap.fromDistinctAscList get - + get = do ascList <- get + case ascList of + [] -> pure IntMap.empty + (i,_):ivs -> do guardAsc i ivs + pure (IntMap.fromDistinctAscList ascList) + where guardAsc _ [] = pure () + guardAsc i ((j,_):ivs) + | i < j = guardAsc j ivs + | otherwise = fail "IntMap keys are not ascending!" + ------------------------------------------------------------------------ -- Queues and Sequences