Skip to content

Commit

Permalink
add NonblockingSTM type
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchellwrosen committed Nov 28, 2023
1 parent a5da3ba commit 30d26f4
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
1 change: 1 addition & 0 deletions ki/ki.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ library
other-modules:
Ki.Internal.ByteCount
Ki.Internal.IO
Ki.Internal.NonblockingSTM
Ki.Internal.Scope
Ki.Internal.Thread

Expand Down
36 changes: 36 additions & 0 deletions ki/src/Ki/Internal/NonblockingSTM.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
-- | STM minus retry. These STM actions are guaranteed not to block, and thus guaranteed not to be interrupted by an
-- async exception.
module Ki.Internal.NonblockingSTM
( NonblockingSTM,
nonblockingAtomically,
nonblockingThrowSTM,

-- * TVar
nonblockingReadTVar,
nonblockingWriteTVar',
)
where

import Control.Exception (Exception)
import Data.Coerce (coerce)
import GHC.Conc (STM, TVar, atomically, readTVar, throwSTM, writeTVar)

newtype NonblockingSTM a
= NonblockingSTM (STM a)
deriving newtype (Applicative, Functor, Monad)

nonblockingAtomically :: forall a. NonblockingSTM a -> IO a
nonblockingAtomically =
coerce @(STM a -> IO a) atomically

nonblockingThrowSTM :: forall e x. (Exception e) => e -> NonblockingSTM x
nonblockingThrowSTM =
coerce @(e -> STM x) throwSTM

nonblockingReadTVar :: forall a. TVar a -> NonblockingSTM a
nonblockingReadTVar =
coerce @(TVar a -> STM a) readTVar

nonblockingWriteTVar' :: forall a. TVar a -> a -> NonblockingSTM ()
nonblockingWriteTVar' var !x =
NonblockingSTM (writeTVar var x)
45 changes: 20 additions & 25 deletions ki/src/Ki/Internal/Scope.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ import GHC.Conc
)
import GHC.Conc.Sync (readTVarIO)
import GHC.IO (unsafeUnmask)
import Ki.Internal.ByteCount
import IntSupply (IntSupply)
import qualified IntSupply
import Ki.Internal.ByteCount
import Ki.Internal.IO
( IOResult (..),
UnexceptionalIO (..),
Expand All @@ -59,6 +59,7 @@ import Ki.Internal.IO
unexceptionalTryEither,
uninterruptiblyMasked,
)
import Ki.Internal.NonblockingSTM
import Ki.Internal.Thread

-- | A scope.
Expand Down Expand Up @@ -229,13 +230,13 @@ spawn
-- or MaskedUninterruptible, to avoid a branch on parentMaskingState.
interruptiblyMasked do
-- Record the thread as being about to start. Not allowed to retry.
atomically do
n <- readTVar statusVar
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
assert (n >= -2) do
case n of
Open -> writeTVar statusVar $! n + 1
Closing -> throwSTM ScopeClosing
Closed -> throwSTM (ErrorCall "ki: scope closed")
Open -> nonblockingWriteTVar' statusVar (n + 1)
Closing -> nonblockingThrowSTM ScopeClosing
Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed")

childId <- IntSupply.next nextChildIdSupply

Expand All @@ -245,11 +246,9 @@ spawn
childThreadId <- myThreadId
labelThread childThreadId label

case allocationLimit of
Nothing -> pure ()
Just bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit
for_ allocationLimit \bytes -> do
setAllocationCounter (byteCountToInt64 bytes)
enableAllocationLimit

let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one.
atRequestedMaskingState :: IO a -> IO a
Expand All @@ -261,12 +260,12 @@ spawn

runUnexceptionalIO (action childId atRequestedMaskingState)

atomically (unrecordChild childrenVar childId)
nonblockingAtomically (unrecordChild childrenVar childId)

-- Record the child as having started. Not allowed to retry.
atomically do
n <- readTVar statusVar
writeTVar statusVar $! n - 1
nonblockingAtomically do
n <- nonblockingReadTVar statusVar
nonblockingWriteTVar' statusVar (n - 1)
recordChild childrenVar childId childThreadId

pure childThreadId
Expand All @@ -275,23 +274,19 @@ spawn
--
-- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself)
-- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself)
--
-- Never retries.
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> STM ()
recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM ()
recordChild childrenVar childId childThreadId = do
children <- readTVar childrenVar
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children)

-- Unrecord a child (ourselves) by either:
--
-- * Flipping `Just childThreadId` to `Nothing` (common case: parent recorded us first)
-- * Flipping `Nothing` to `Just undefined` (uncommon case: we terminate and unrecord before parent can record us).
--
-- Never retries.
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> STM ()
unrecordChild :: TVar (IntMap ThreadId) -> Tid -> NonblockingSTM ()
unrecordChild childrenVar childId = do
children <- readTVar childrenVar
writeTVar childrenVar $! IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children
children <- nonblockingReadTVar childrenVar
nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just undefined) (const Nothing)) childId children)

-- | Wait until all threads created within a scope terminate.
awaitAll :: Scope -> STM ()
Expand Down

0 comments on commit 30d26f4

Please sign in to comment.