diff --git a/selda-sqlite/src/Database/Selda/SQLite.hs b/selda-sqlite/src/Database/Selda/SQLite.hs index 352ff34..ad5bb4d 100644 --- a/selda-sqlite/src/Database/Selda/SQLite.hs +++ b/selda-sqlite/src/Database/Selda/SQLite.hs @@ -32,17 +32,18 @@ sqliteOpen :: (MonadIO m, MonadMask m) => FilePath -> m (SeldaConnection SQLite) #ifdef __HASTE__ sqliteOpen _ = error "sqliteOpen called in JS context" #else -sqliteOpen file = do - mask $ \restore -> do - edb <- try $ liftIO $ open (pack file) - case edb of - Left e@(SQLError{}) -> do - throwM (DbError (show e)) - Right db -> flip onException (liftIO (close db)) . restore $ do - absFile <- liftIO $ pack <$> makeAbsolute file - let backend = sqliteBackend db - void . liftIO $ runStmt backend "PRAGMA foreign_keys = ON;" [] - newConnection backend absFile +sqliteOpen file = + bracketOnError acquire (liftIO . close) $ \db -> do + absFile <- liftIO $ pack <$> makeAbsolute file + let backend = sqliteBackend db + void . liftIO $ runStmt backend "PRAGMA foreign_keys = ON;" [] + newConnection backend absFile + where + acquire = do + edb <- try $ liftIO $ open (pack file) + case edb of + Left e@(SQLError{}) -> throwM (DbError (show e)) + Right db -> pure db #endif -- | Perform the given computation over an SQLite database. diff --git a/selda-tests/selda-tests.cabal b/selda-tests/selda-tests.cabal index 40ed106..953fefd 100644 --- a/selda-tests/selda-tests.cabal +++ b/selda-tests/selda-tests.cabal @@ -38,6 +38,7 @@ test-suite selda-testsuite , directory >=1.2 && <1.4 , exceptions >=0.8 && <0.11 , HUnit >=1.4 && <1.7 + , mtl >=2.0 && <2.4 , selda , selda-json , text >=1.1 && <2.1 diff --git a/selda-tests/test/Tests/Mutable.hs b/selda-tests/test/Tests/Mutable.hs index dfec5be..8793153 100644 --- a/selda-tests/test/Tests/Mutable.hs +++ b/selda-tests/test/Tests/Mutable.hs @@ -5,6 +5,9 @@ module Tests.Mutable (mutableTests) where import Control.Concurrent import Control.Monad.Catch +#if MIN_VERSION_mtl(2, 1, 1) +import Control.Monad.Except (runExceptT, throwError) +#endif import Data.ByteString (ByteString) import qualified Data.ByteString.Lazy as Lazy (ByteString) import Data.List hiding (groupBy, insert) @@ -36,6 +39,9 @@ mutableTests freshEnv = test , "insert time values" ~: freshEnv insertTime , "transaction completes" ~: freshEnv transactionCompletes , "transaction rolls back" ~: freshEnv transactionRollsBack +#if MIN_VERSION_mtl(2, 1, 1) + , "transaction rolls back (ExceptT)"~: freshEnv transactionRollsBackExceptT +#endif , "queries are consistent" ~: freshEnv consistentQueries , "delete deletes" ~: freshEnv deleteDeletes , "delete everything" ~: freshEnv deleteEverything @@ -192,6 +198,31 @@ transactionRollsBack = do c1 = "チョロゴン" c2 = "メイド最高!" +#if MIN_VERSION_mtl(2, 1, 1) +transactionRollsBackExceptT :: SeldaM b () +transactionRollsBackExceptT = do + setup + res <- runExceptT $ transaction $ do + insert_ comments [(def, Just "Kobayashi", c1)] + insert_ comments + [ (def, Nothing, "more anonymous spam") + , (def, Just "Kobayashi", c2) + ] + throwError "nope" + case res of + Right _ -> + liftIO $ assertFailure "error didn't propagate" + Left (_ :: String) -> do + cs <- query $ do + t <- select comments + restrict (t!cName .== just "Kobayashi") + return (t!cComment) + assEq "commit was not rolled back" [] cs + where + c1 = "チョロゴン" + c2 = "メイド最高!" +#endif + consistentQueries = do setup a <- query q diff --git a/selda/selda.cabal b/selda/selda.cabal index 83a6fbf..9c1e44d 100644 --- a/selda/selda.cabal +++ b/selda/selda.cabal @@ -73,7 +73,7 @@ library build-depends: base >=4.10 && <5 , bytestring >=0.10 && <0.12 - , exceptions >=0.8 && <0.11 + , exceptions >=0.9 && <0.11 , mtl >=2.0 && <2.4 , text >=1.0 && <2.1 , time >=1.5 && <1.13 diff --git a/selda/src/Database/Selda/Backend/Internal.hs b/selda/src/Database/Selda/Backend/Internal.hs index b52fafa..6eeb664 100644 --- a/selda/src/Database/Selda/Backend/Internal.hs +++ b/selda/src/Database/Selda/Backend/Internal.hs @@ -1,4 +1,5 @@ {-# LANGUAGE GeneralizedNewtypeDeriving, CPP, TypeFamilies #-} +{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-} -- | Internal backend API. -- Using anything exported from this module may or may not invalidate any -- safety guarantees made by Selda; use at your own peril. @@ -42,9 +43,22 @@ import Data.Int (Int64) import Control.Concurrent ( newMVar, putMVar, takeMVar, MVar ) import Control.Monad.Catch ( Exception, bracket, MonadCatch, MonadMask, MonadThrow(..) ) +import Control.Monad.Error.Class ( MonadError ) +#if MIN_VERSION_mtl(2, 1, 1) +import Control.Monad.Except ( ExceptT(..), mapExceptT, runExceptT ) +#endif import Control.Monad.IO.Class ( MonadIO(..) ) import Control.Monad.Reader - ( MonadTrans(..), when, ReaderT(..), MonadReader(ask) ) + ( MonadTrans(..), when, ReaderT(..), MonadReader(ask, local, reader), mapReaderT ) +import Control.Monad.RWS.Class ( MonadRWS ) +import qualified Control.Monad.RWS.Lazy as Lazy ( RWST(..), mapRWST ) +import qualified Control.Monad.RWS.Strict as Strict ( RWST(..), mapRWST ) +import Control.Monad.State.Class ( MonadState ) +import qualified Control.Monad.State.Lazy as Lazy ( StateT(..), mapStateT ) +import qualified Control.Monad.State.Strict as Strict ( StateT(..), mapStateT ) +import Control.Monad.Writer.Class ( MonadWriter ) +import qualified Control.Monad.Writer.Lazy as Lazy ( WriterT(..), mapWriterT ) +import qualified Control.Monad.Writer.Strict as Strict ( WriterT(..), mapWriterT ) import Data.Dynamic ( Typeable, Dynamic ) import qualified Data.IntMap as M import Data.IORef @@ -279,9 +293,18 @@ withBackend m = withConnection (m . connBackend) -- | Monad transformer adding Selda SQL capabilities. newtype SeldaT b m a = S {unS :: ReaderT (SeldaConnection b) m a} deriving ( Functor, Applicative, Monad, MonadIO - , MonadThrow, MonadCatch, MonadMask , MonadFail + , MonadThrow, MonadCatch, MonadMask, MonadFail + , MonadError e, MonadWriter w, MonadState s + , MonadRWS r w s ) +-- This instance has to be defined manually since we want to pass through +-- SeldaT's ReaderT. +instance MonadReader r m => MonadReader r (SeldaT b m) where + ask = lift ask + local = mapSeldaT . local + reader = lift . reader + instance (MonadIO m, MonadMask m) => MonadSelda (SeldaT b m) where type Backend (SeldaT b m) = b withConnection m = S ask >>= m @@ -289,6 +312,53 @@ instance (MonadIO m, MonadMask m) => MonadSelda (SeldaT b m) where instance MonadTrans (SeldaT b) where lift = S . lift +instance MonadSelda m => MonadSelda (ReaderT r m) where + type Backend (ReaderT r m) = Backend m + withConnection f = ReaderT $ \r -> + withConnection (\conn -> runReaderT (f conn) r) + transact = mapReaderT transact + +instance (Monoid w, MonadSelda m) => MonadSelda (Lazy.WriterT w m) where + type Backend (Lazy.WriterT w m) = Backend m + withConnection f = Lazy.WriterT $ withConnection (Lazy.runWriterT . f) + transact = Lazy.mapWriterT transact + +instance (Monoid w, MonadSelda m) => MonadSelda (Strict.WriterT w m) where + type Backend (Strict.WriterT w m) = Backend m + withConnection f = Strict.WriterT $ withConnection (Strict.runWriterT . f) + transact = Strict.mapWriterT transact + +instance MonadSelda m => MonadSelda (Lazy.StateT s m) where + type Backend (Lazy.StateT s m) = Backend m + withConnection f = Lazy.StateT $ \s -> + withConnection (\conn -> Lazy.runStateT (f conn) s) + transact = Lazy.mapStateT transact + +instance MonadSelda m => MonadSelda (Strict.StateT s m) where + type Backend (Strict.StateT s m) = Backend m + withConnection f = Strict.StateT $ \s -> + withConnection (\conn -> Strict.runStateT (f conn) s) + transact = Strict.mapStateT transact + +instance (Monoid w, MonadSelda m) => MonadSelda (Lazy.RWST r w s m) where + type Backend (Lazy.RWST r w s m) = Backend m + withConnection f = Lazy.RWST $ \r s -> + withConnection (\conn -> Lazy.runRWST (f conn) r s) + transact = Lazy.mapRWST transact + +instance (Monoid w, MonadSelda m) => MonadSelda (Strict.RWST r w s m) where + type Backend (Strict.RWST r w s m) = Backend m + withConnection f = Strict.RWST $ \r s -> + withConnection (\conn -> Strict.runRWST (f conn) r s) + transact = Strict.mapRWST transact + +#if MIN_VERSION_mtl(2, 1, 1) +instance MonadSelda m => MonadSelda (ExceptT e m) where + type Backend (ExceptT e m) = Backend m + withConnection f = ExceptT $ withConnection (runExceptT . f) + transact = mapExceptT transact +#endif + -- | The simplest form of Selda computation; 'SeldaT' specialized to 'IO'. type SeldaM b = SeldaT b IO @@ -308,3 +378,7 @@ runSeldaT m c = when closed $ do liftIO $ throwM $ DbError "runSeldaT called with a closed connection" runReaderT (unS m) c + +-- | Transform the computation inside a 'SeldaT'. +mapSeldaT :: (m a -> n b) -> SeldaT b' m a -> SeldaT b' n b +mapSeldaT f m = S $ mapReaderT f (unS m) diff --git a/selda/src/Database/Selda/Frontend.hs b/selda/src/Database/Selda/Frontend.hs index 247ae6d..4c8779c 100644 --- a/selda/src/Database/Selda/Frontend.hs +++ b/selda/src/Database/Selda/Frontend.hs @@ -44,11 +44,10 @@ import Data.Text (Text) import Control.Monad ( void ) import Control.Monad.Catch ( bracket_, - onException, try, MonadCatch, - MonadMask(mask), - MonadThrow(throwM) ) + MonadMask(generalBracket), + MonadThrow(throwM), ExitCase (..) ) import Control.Monad.IO.Class ( MonadIO(..) ) -- | Run a query within a Selda monad. In practice, this is often a 'SeldaT' @@ -259,11 +258,11 @@ tryDropTable = void . flip exec [] . compileDropTable Ignore -- will be rolled back and the exception re-thrown, even if the exception -- is caught and handled within the transaction. transaction :: (MonadSelda m, MonadMask m) => m a -> m a -transaction m = mask $ \restore -> transact $ do - void $ exec "BEGIN TRANSACTION" [] - x <- restore m `onException` void (exec "ROLLBACK" []) - void $ exec "COMMIT" [] - return x +transaction m = + fst <$> generalBracket (exec "BEGIN TRANSACTION" []) (const finish) (const m) + where + finish (ExitCaseSuccess _) = exec "COMMIT" [] + finish _ = exec "ROLLBACK" [] -- | Run the given computation as a transaction without enforcing foreign key -- constraints.