diff --git a/ki/ki.cabal b/ki/ki.cabal index f74c0bd..2ed6cd1 100644 --- a/ki/ki.cabal +++ b/ki/ki.cabal @@ -90,6 +90,7 @@ library other-modules: Ki.Internal.ByteCount Ki.Internal.IO + Ki.Internal.NonblockingSTM Ki.Internal.Scope Ki.Internal.Thread diff --git a/ki/src/Ki/Internal/NonblockingSTM.hs b/ki/src/Ki/Internal/NonblockingSTM.hs new file mode 100644 index 0000000..4d94f81 --- /dev/null +++ b/ki/src/Ki/Internal/NonblockingSTM.hs @@ -0,0 +1,36 @@ +-- | STM minus retry. These STM actions are guaranteed not to block, and thus guaranteed not to be interrupted by an +-- async exception. +module Ki.Internal.NonblockingSTM + ( NonblockingSTM, + nonblockingAtomically, + nonblockingThrowSTM, + + -- * TVar + nonblockingReadTVar, + nonblockingWriteTVar', + ) +where + +import Control.Exception (Exception) +import Data.Coerce (coerce) +import GHC.Conc (STM, TVar, atomically, readTVar, throwSTM, writeTVar) + +newtype NonblockingSTM a + = NonblockingSTM (STM a) + deriving newtype (Applicative, Functor, Monad) + +nonblockingAtomically :: forall a. NonblockingSTM a -> IO a +nonblockingAtomically = + coerce @(STM a -> IO a) atomically + +nonblockingThrowSTM :: forall e x. (Exception e) => e -> NonblockingSTM x +nonblockingThrowSTM = + coerce @(e -> STM x) throwSTM + +nonblockingReadTVar :: forall a. TVar a -> NonblockingSTM a +nonblockingReadTVar = + coerce @(TVar a -> STM a) readTVar + +nonblockingWriteTVar' :: forall a. TVar a -> a -> NonblockingSTM () +nonblockingWriteTVar' var !x = + NonblockingSTM (writeTVar var x) diff --git a/ki/src/Ki/Internal/Scope.hs b/ki/src/Ki/Internal/Scope.hs index 8c65f7a..678cb4f 100644 --- a/ki/src/Ki/Internal/Scope.hs +++ b/ki/src/Ki/Internal/Scope.hs @@ -47,9 +47,9 @@ import GHC.Conc ) import GHC.Conc.Sync (readTVarIO) import GHC.IO (unsafeUnmask) -import Ki.Internal.ByteCount import IntSupply (IntSupply) import qualified IntSupply +import Ki.Internal.ByteCount import Ki.Internal.IO ( IOResult (..), UnexceptionalIO (..), @@ -59,6 +59,7 @@ import Ki.Internal.IO unexceptionalTryEither, uninterruptiblyMasked, ) +import Ki.Internal.NonblockingSTM import Ki.Internal.Thread -- | A scope. @@ -229,13 +230,13 @@ spawn -- or MaskedUninterruptible, to avoid a branch on parentMaskingState. interruptiblyMasked do -- Record the thread as being about to start. Not allowed to retry. - atomically do - n <- readTVar statusVar + nonblockingAtomically do + n <- nonblockingReadTVar statusVar assert (n >= -2) do case n of - Open -> writeTVar statusVar $! n + 1 - Closing -> throwSTM ScopeClosing - Closed -> throwSTM (ErrorCall "ki: scope closed") + Open -> nonblockingWriteTVar' statusVar (n + 1) + Closing -> nonblockingThrowSTM ScopeClosing + Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed") childId <- IntSupply.next nextChildIdSupply @@ -245,11 +246,9 @@ spawn childThreadId <- myThreadId labelThread childThreadId label - case allocationLimit of - Nothing -> pure () - Just bytes -> do - setAllocationCounter (byteCountToInt64 bytes) - enableAllocationLimit + for_ allocationLimit \bytes -> do + setAllocationCounter (byteCountToInt64 bytes) + enableAllocationLimit let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one. atRequestedMaskingState :: IO a -> IO a @@ -261,12 +260,12 @@ spawn runUnexceptionalIO (action childId atRequestedMaskingState) - atomically (unrecordChild childrenVar childId) + nonblockingAtomically (unrecordChild childrenVar childId) -- Record the child as having started. Not allowed to retry. - atomically do - n <- readTVar statusVar - writeTVar statusVar $! n - 1 + nonblockingAtomically do + n <- nonblockingReadTVar statusVar + nonblockingWriteTVar' statusVar (n - 1) recordChild childrenVar childId childThreadId pure childThreadId @@ -275,23 +274,19 @@ spawn -- -- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself) -- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself) --- --- Never retries. -recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> STM () +recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM () recordChild childrenVar childId childThreadId = do - children <- readTVar childrenVar - writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children + children <- nonblockingReadTVar childrenVar + nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children) -- Unrecord a child (ourselves) by either: -- -- * Flipping `Just childThreadId` to `Nothing` (common case: parent recorded us first) -- * Flipping `Nothing` to `Just undefined` (uncommon case: we terminate and unrecord before parent can record us). --- --- Never retries. -unrecordChild :: TVar (IntMap ThreadId) -> Tid -> STM () +unrecordChild :: TVar (IntMap ThreadId) -> Tid -> NonblockingSTM () unrecordChild childrenVar childId = do - children <- readTVar childrenVar - writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children + children <- nonblockingReadTVar childrenVar + nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children) -- | Wait until all threads created within a scope terminate. awaitAll :: Scope -> STM ()