Skip to content

Commit

Permalink
Protect against corruption due to duplicated streaming
Browse files Browse the repository at this point in the history
This should address #4 by using the same treatment as used for `zlib` in

ndmitchell/zlib@15cb310
  • Loading branch information
hvr committed Oct 23, 2016
1 parent 86683d5 commit 42ca5f5
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions src/Codec/Compression/Lzma.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ import Control.Exception
import Control.Monad
import Control.Monad.ST (stToIO)
import Control.Monad.ST.Lazy (ST, runST, strictToLazyST)
import qualified Control.Monad.ST.Strict as ST.Strict (ST)
import Control.Monad.ST.Unsafe (unsafeIOToST)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Lazy.Internal as BSL
import GHC.IO (noDuplicate)
import LibLzma

-- | Decompress lazy 'ByteString' from the @.xz@ format
Expand Down Expand Up @@ -199,7 +202,8 @@ compressIO parms = (stToIO $ newEncodeLzmaStream parms) >>= either throwIO go

-- | Incremental compression in the lazy 'ST' monad.
compressST :: CompressParams -> ST s (CompressStream (ST s))
compressST parms = strictToLazyST (newEncodeLzmaStream parms) >>= either throw go
compressST parms = strictToLazyST (newEncodeLzmaStream parms) >>=
either throw go
where
bUFSIZ = 32752

Expand All @@ -209,7 +213,8 @@ compressST parms = strictToLazyST (newEncodeLzmaStream parms) >>= either throw g

goInput :: ByteString -> ST s (CompressStream (ST s))
goInput chunk = do
(rc, used, obuf) <- strictToLazyST $ runLzmaStream ls chunk LzmaRun bUFSIZ
(rc, used, obuf) <- strictToLazyST (noDuplicateST >>
runLzmaStream ls chunk LzmaRun bUFSIZ)

let chunk' = BS.drop used chunk

Expand All @@ -234,7 +239,8 @@ compressST parms = strictToLazyST (newEncodeLzmaStream parms) >>= either throw g
goSync action next = goSync'
where
goSync' = do
(rc, 0, obuf) <- strictToLazyST $ runLzmaStream ls BS.empty action bUFSIZ
(rc, 0, obuf) <- strictToLazyST (noDuplicateST >>
runLzmaStream ls BS.empty action bUFSIZ)
case rc of
LzmaRetOK
| BS.null obuf -> fail ("compressIO: empty output chunk during " ++ show action)
Expand All @@ -245,7 +251,7 @@ compressST parms = strictToLazyST (newEncodeLzmaStream parms) >>= either throw g
_ -> throw rc

retStreamEnd = do
!() <- strictToLazyST (endLzmaStream ls)
!() <- strictToLazyST (noDuplicateST >> endLzmaStream ls)
return CompressStreamEnd

--------------------------------------------------------------------------------
Expand Down Expand Up @@ -319,7 +325,8 @@ decompressIO parms = stToIO (newDecodeLzmaStream parms) >>= either (return . Dec

-- | Incremental decompression in the lazy 'ST' monad.
decompressST :: DecompressParams -> ST s (DecompressStream (ST s))
decompressST parms = strictToLazyST (newDecodeLzmaStream parms) >>= either (return . DecompressStreamError) go
decompressST parms = strictToLazyST (newDecodeLzmaStream parms) >>=
either (return . DecompressStreamError) go
where
bUFSIZ = 32752

Expand All @@ -332,7 +339,8 @@ decompressST parms = strictToLazyST (newDecodeLzmaStream parms) >>= either (retu
goInput chunk
| BS.null chunk = goFinish
| otherwise = do
(rc, used, obuf) <- strictToLazyST $ runLzmaStream ls chunk LzmaRun bUFSIZ
(rc, used, obuf) <- strictToLazyST (noDuplicateST >>
runLzmaStream ls chunk LzmaRun bUFSIZ)

let chunk' = BS.drop used chunk

Expand Down Expand Up @@ -361,7 +369,8 @@ decompressST parms = strictToLazyST (newDecodeLzmaStream parms) >>= either (retu
goSync action next = goSync'
where
goSync' = do
(rc, 0, obuf) <- strictToLazyST $ runLzmaStream ls BS.empty action bUFSIZ
(rc, 0, obuf) <- strictToLazyST (noDuplicateST >>
runLzmaStream ls BS.empty action bUFSIZ)
case rc of
LzmaRetOK
| BS.null obuf -> next
Expand All @@ -376,7 +385,7 @@ decompressST parms = strictToLazyST (newDecodeLzmaStream parms) >>= either (retu
eof0 = retStreamEnd BS.empty

retStreamEnd chunk' = do
!() <- strictToLazyST (endLzmaStream ls)
!() <- strictToLazyST (noDuplicateST >> endLzmaStream ls)
return (DecompressStreamEnd chunk')

-- | Small 'maybe'-ish helper distinguishing between empty and
Expand All @@ -385,3 +394,7 @@ withChunk :: t -> (ByteString -> t) -> ByteString -> t
withChunk emptyChunk nemptyChunk chunk
| BS.null chunk = emptyChunk
| otherwise = nemptyChunk chunk

-- | See <https://github.com/haskell/zlib/issues/7>
noDuplicateST :: ST.Strict.ST s ()
noDuplicateST = unsafeIOToST noDuplicate

0 comments on commit 42ca5f5

Please sign in to comment.