diff --git a/Control/Monad/Invert.hs b/Control/Monad/Invert.hs index 6eb791a..f322b75 100644 --- a/Control/Monad/Invert.hs +++ b/Control/Monad/Invert.hs @@ -9,6 +9,7 @@ module Control.Monad.Invert , unblock , bracket , bracket_ + , onException -- * Memory allocation , alloca , allocaBytes @@ -29,8 +30,9 @@ import Foreign.Storable (Storable) import Foreign.Ptr (Ptr) import Foreign.ForeignPtr (ForeignPtr) import qualified Foreign.ForeignPtr as F +import Control.Monad.IO.Class (MonadIO) -class Monad m => MonadInvertIO m where +class MonadIO m => MonadInvertIO m where data InvertedIO m :: * -> * type InvertedArg m invertIO :: m a -> InvertedArg m -> IO (InvertedIO m a) @@ -81,6 +83,10 @@ finally :: MonadInvertIO m => m a -> m b -> m a finally action after = revertIO $ \a -> invertIO action a `E.finally` invertIO after a +onException :: MonadInvertIO m => m a -> m b -> m a +onException action after = + revertIO $ \a -> invertIO action a `E.onException` invertIO after a + catch :: (E.Exception e, MonadInvertIO m) => m a -> (e -> m a) -> m a catch action handler = revertIO $ \a -> invertIO action a `E.catch` (\e -> invertIO (handler e) a) diff --git a/runtests.hs b/runtests.hs index c71b157..7f4b955 100644 --- a/runtests.hs +++ b/runtests.hs @@ -43,6 +43,7 @@ testSuite s run = testGroup s -- FIXME test block and unblock , testCase "bracket" $ case_bracket run , testCase "bracket_" $ case_bracket_ run + , testCase "onException" $ case_onException run ] ignore :: IO () -> IO () @@ -98,6 +99,20 @@ case_bracket_ run = do j <- readIORef i j @?= 4 +case_onException :: (MonadIO m, MonadInvertIO m) => (m () -> IO ()) -> Assertion +case_onException run = do + i <- newIORef one + ignore $ run $ onException + (liftIO (writeIORef i 2) >> error "ignored") + (liftIO $ writeIORef i 3) + j <- readIORef i + j @?= 3 + ignore $ run $ onException + (liftIO $ writeIORef i 4) + (liftIO $ writeIORef i 5) + k <- readIORef i + k @?= 4 + case_throwError :: Assertion case_throwError = do i <- newIORef one