Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve monad transformers support #198

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions selda-sqlite/src/Database/Selda/SQLite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ sqliteOpen :: (MonadIO m, MonadMask m) => FilePath -> m (SeldaConnection SQLite)
#ifdef __HASTE__
sqliteOpen _ = error "sqliteOpen called in JS context"
#else
sqliteOpen file = do
mask $ \restore -> do
edb <- try $ liftIO $ open (pack file)
case edb of
Left e@(SQLError{}) -> do
throwM (DbError (show e))
Right db -> flip onException (liftIO (close db)) . restore $ do
absFile <- liftIO $ pack <$> makeAbsolute file
let backend = sqliteBackend db
void . liftIO $ runStmt backend "PRAGMA foreign_keys = ON;" []
newConnection backend absFile
sqliteOpen file =
bracketOnError acquire (liftIO . close) $ \db -> do
absFile <- liftIO $ pack <$> makeAbsolute file
let backend = sqliteBackend db
void . liftIO $ runStmt backend "PRAGMA foreign_keys = ON;" []
newConnection backend absFile
where
acquire = do
edb <- try $ liftIO $ open (pack file)
case edb of
Left e@(SQLError{}) -> throwM (DbError (show e))
Right db -> pure db
#endif

-- | Perform the given computation over an SQLite database.
Expand Down
1 change: 1 addition & 0 deletions selda-tests/selda-tests.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ test-suite selda-testsuite
, directory >=1.2 && <1.4
, exceptions >=0.8 && <0.11
, HUnit >=1.4 && <1.7
, mtl >=2.0 && <2.4
, selda
, selda-json
, text >=1.1 && <2.1
Expand Down
31 changes: 31 additions & 0 deletions selda-tests/test/Tests/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
module Tests.Mutable (mutableTests) where
import Control.Concurrent
import Control.Monad.Catch
#if MIN_VERSION_mtl(2, 1, 1)
import Control.Monad.Except (runExceptT, throwError)
#endif
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import Data.List hiding (groupBy, insert)
Expand Down Expand Up @@ -36,6 +39,9 @@ mutableTests freshEnv = test
, "insert time values" ~: freshEnv insertTime
, "transaction completes" ~: freshEnv transactionCompletes
, "transaction rolls back" ~: freshEnv transactionRollsBack
#if MIN_VERSION_mtl(2, 1, 1)
, "transaction rolls back (ExceptT)"~: freshEnv transactionRollsBackExceptT
#endif
, "queries are consistent" ~: freshEnv consistentQueries
, "delete deletes" ~: freshEnv deleteDeletes
, "delete everything" ~: freshEnv deleteEverything
Expand Down Expand Up @@ -192,6 +198,31 @@ transactionRollsBack = do
c1 = "チョロゴン"
c2 = "メイド最高!"

#if MIN_VERSION_mtl(2, 1, 1)
transactionRollsBackExceptT :: SeldaM b ()
transactionRollsBackExceptT = do
setup
res <- runExceptT $ transaction $ do
insert_ comments [(def, Just "Kobayashi", c1)]
insert_ comments
[ (def, Nothing, "more anonymous spam")
, (def, Just "Kobayashi", c2)
]
throwError "nope"
case res of
Right _ ->
liftIO $ assertFailure "error didn't propagate"
Left (_ :: String) -> do
cs <- query $ do
t <- select comments
restrict (t!cName .== just "Kobayashi")
return (t!cComment)
assEq "commit was not rolled back" [] cs
where
c1 = "チョロゴン"
c2 = "メイド最高!"
#endif

consistentQueries = do
setup
a <- query q
Expand Down
2 changes: 1 addition & 1 deletion selda/selda.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ library
build-depends:
base >=4.10 && <5
, bytestring >=0.10 && <0.12
, exceptions >=0.8 && <0.11
, exceptions >=0.9 && <0.11
, mtl >=2.0 && <2.4
, text >=1.0 && <2.1
, time >=1.5 && <1.13
Expand Down
78 changes: 76 additions & 2 deletions selda/src/Database/Selda/Backend/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE GeneralizedNewtypeDeriving, CPP, TypeFamilies #-}
{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}
-- | Internal backend API.
-- Using anything exported from this module may or may not invalidate any
-- safety guarantees made by Selda; use at your own peril.
Expand Down Expand Up @@ -42,9 +43,22 @@ import Data.Int (Int64)
import Control.Concurrent ( newMVar, putMVar, takeMVar, MVar )
import Control.Monad.Catch
( Exception, bracket, MonadCatch, MonadMask, MonadThrow(..) )
import Control.Monad.Error.Class ( MonadError )
#if MIN_VERSION_mtl(2, 1, 1)
import Control.Monad.Except ( ExceptT(..), mapExceptT, runExceptT )
#endif
import Control.Monad.IO.Class ( MonadIO(..) )
import Control.Monad.Reader
( MonadTrans(..), when, ReaderT(..), MonadReader(ask) )
( MonadTrans(..), when, ReaderT(..), MonadReader(ask, local, reader), mapReaderT )
import Control.Monad.RWS.Class ( MonadRWS )
import qualified Control.Monad.RWS.Lazy as Lazy ( RWST(..), mapRWST )
import qualified Control.Monad.RWS.Strict as Strict ( RWST(..), mapRWST )
import Control.Monad.State.Class ( MonadState )
import qualified Control.Monad.State.Lazy as Lazy ( StateT(..), mapStateT )
import qualified Control.Monad.State.Strict as Strict ( StateT(..), mapStateT )
import Control.Monad.Writer.Class ( MonadWriter )
import qualified Control.Monad.Writer.Lazy as Lazy ( WriterT(..), mapWriterT )
import qualified Control.Monad.Writer.Strict as Strict ( WriterT(..), mapWriterT )
import Data.Dynamic ( Typeable, Dynamic )
import qualified Data.IntMap as M
import Data.IORef
Expand Down Expand Up @@ -279,16 +293,72 @@ withBackend m = withConnection (m . connBackend)
-- | Monad transformer adding Selda SQL capabilities.
newtype SeldaT b m a = S {unS :: ReaderT (SeldaConnection b) m a}
deriving ( Functor, Applicative, Monad, MonadIO
, MonadThrow, MonadCatch, MonadMask , MonadFail
, MonadThrow, MonadCatch, MonadMask, MonadFail
, MonadError e, MonadWriter w, MonadState s
, MonadRWS r w s
)

-- This instance has to be defined manually since we want to pass through
-- SeldaT's ReaderT.
instance MonadReader r m => MonadReader r (SeldaT b m) where
ask = lift ask
local = mapSeldaT . local
reader = lift . reader

instance (MonadIO m, MonadMask m) => MonadSelda (SeldaT b m) where
type Backend (SeldaT b m) = b
withConnection m = S ask >>= m

instance MonadTrans (SeldaT b) where
lift = S . lift

instance MonadSelda m => MonadSelda (ReaderT r m) where
type Backend (ReaderT r m) = Backend m
withConnection f = ReaderT $ \r ->
withConnection (\conn -> runReaderT (f conn) r)
transact = mapReaderT transact

instance (Monoid w, MonadSelda m) => MonadSelda (Lazy.WriterT w m) where
type Backend (Lazy.WriterT w m) = Backend m
withConnection f = Lazy.WriterT $ withConnection (Lazy.runWriterT . f)
transact = Lazy.mapWriterT transact

instance (Monoid w, MonadSelda m) => MonadSelda (Strict.WriterT w m) where
type Backend (Strict.WriterT w m) = Backend m
withConnection f = Strict.WriterT $ withConnection (Strict.runWriterT . f)
transact = Strict.mapWriterT transact

instance MonadSelda m => MonadSelda (Lazy.StateT s m) where
type Backend (Lazy.StateT s m) = Backend m
withConnection f = Lazy.StateT $ \s ->
withConnection (\conn -> Lazy.runStateT (f conn) s)
transact = Lazy.mapStateT transact

instance MonadSelda m => MonadSelda (Strict.StateT s m) where
type Backend (Strict.StateT s m) = Backend m
withConnection f = Strict.StateT $ \s ->
withConnection (\conn -> Strict.runStateT (f conn) s)
transact = Strict.mapStateT transact

instance (Monoid w, MonadSelda m) => MonadSelda (Lazy.RWST r w s m) where
type Backend (Lazy.RWST r w s m) = Backend m
withConnection f = Lazy.RWST $ \r s ->
withConnection (\conn -> Lazy.runRWST (f conn) r s)
transact = Lazy.mapRWST transact

instance (Monoid w, MonadSelda m) => MonadSelda (Strict.RWST r w s m) where
type Backend (Strict.RWST r w s m) = Backend m
withConnection f = Strict.RWST $ \r s ->
withConnection (\conn -> Strict.runRWST (f conn) r s)
transact = Strict.mapRWST transact

#if MIN_VERSION_mtl(2, 1, 1)
instance MonadSelda m => MonadSelda (ExceptT e m) where
type Backend (ExceptT e m) = Backend m
withConnection f = ExceptT $ withConnection (runExceptT . f)
transact = mapExceptT transact
#endif

-- | The simplest form of Selda computation; 'SeldaT' specialized to 'IO'.
type SeldaM b = SeldaT b IO

Expand All @@ -308,3 +378,7 @@ runSeldaT m c =
when closed $ do
liftIO $ throwM $ DbError "runSeldaT called with a closed connection"
runReaderT (unS m) c

-- | Transform the computation inside a 'SeldaT'.
mapSeldaT :: (m a -> n b) -> SeldaT b' m a -> SeldaT b' n b
mapSeldaT f m = S $ mapReaderT f (unS m)
15 changes: 7 additions & 8 deletions selda/src/Database/Selda/Frontend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ import Data.Text (Text)
import Control.Monad ( void )
import Control.Monad.Catch
( bracket_,
onException,
try,
MonadCatch,
MonadMask(mask),
MonadThrow(throwM) )
MonadMask(generalBracket),
MonadThrow(throwM), ExitCase (..) )
import Control.Monad.IO.Class ( MonadIO(..) )

-- | Run a query within a Selda monad. In practice, this is often a 'SeldaT'
Expand Down Expand Up @@ -259,11 +258,11 @@ tryDropTable = void . flip exec [] . compileDropTable Ignore
-- will be rolled back and the exception re-thrown, even if the exception
-- is caught and handled within the transaction.
transaction :: (MonadSelda m, MonadMask m) => m a -> m a
transaction m = mask $ \restore -> transact $ do
void $ exec "BEGIN TRANSACTION" []
x <- restore m `onException` void (exec "ROLLBACK" [])
void $ exec "COMMIT" []
return x
transaction m =
fst <$> generalBracket (exec "BEGIN TRANSACTION" []) (const finish) (const m)
where
finish (ExitCaseSuccess _) = exec "COMMIT" []
finish _ = exec "ROLLBACK" []

-- | Run the given computation as a transaction without enforcing foreign key
-- constraints.
Expand Down