From f5e9a45fa8059e467b0eaf597bf186355f4b0968 Mon Sep 17 00:00:00 2001 From: Mitchell Rosen Date: Mon, 27 Nov 2023 22:33:30 -0500 Subject: [PATCH] refactoring, cleanup --- ki/src/Ki/Internal/IO.hs | 28 ++++---- ki/src/Ki/Internal/Scope.hs | 133 +++++++++++++++++++----------------- 2 files changed, 87 insertions(+), 74 deletions(-) diff --git a/ki/src/Ki/Internal/IO.hs b/ki/src/Ki/Internal/IO.hs index 3a3688f..719de51 100644 --- a/ki/src/Ki/Internal/IO.hs +++ b/ki/src/Ki/Internal/IO.hs @@ -24,13 +24,15 @@ where import Control.Exception import Control.Monad (join) import Data.Coerce (coerce) +import Data.Maybe (isJust) import GHC.Base (maskAsyncExceptions#, maskUninterruptible#) import GHC.Conc (STM, ThreadId (ThreadId), catchSTM) import GHC.Exts (Int (I#), fork#, forkOn#) import GHC.IO (IO (IO)) import Prelude --- A little promise that this IO action cannot throw an exception. +-- A little promise that this IO action cannot throw an exception (*including* async exceptions, which you normally +-- think of as being able to strike at any time). -- -- Yeah it's verbose, and maybe not that necessary, but the code that bothers to use it really does require -- un-exceptiony IO actions for correctness, so here we are. @@ -42,13 +44,17 @@ data IOResult a = Failure !SomeException -- sync or async exception | Success a +-- Try an action, catching any exception it throws. +-- +-- The caller is responsible for ensuring that async exceptions are masked (at whatever masking level is appropriate), +-- as (again) `UnexceptionalIO` implies async exceptions won't be thrown either. unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a) unexceptionalTry action = UnexceptionalIO do (Success <$> action) `catch` \exception -> pure (Failure exception) --- Like try, but with continuations. Also, catches all exceptions, because that's the only flavor we need. +-- Like try, but with continuations. unexceptionalTryEither :: forall a b. (SomeException -> UnexceptionalIO b) -> @@ -63,20 +69,18 @@ unexceptionalTryEither onFailure onSuccess action = (pure . coerce @_ @(SomeException -> IO b) onFailure) isAsyncException :: SomeException -> Bool -isAsyncException exception = - case fromException @SomeAsyncException exception of - Nothing -> False - Just _ -> True +isAsyncException = + isJust . fromException @SomeAsyncException -- | Call an action with asynchronous exceptions interruptibly masked. -interruptiblyMasked :: IO a -> IO a -interruptiblyMasked (IO io) = - IO (maskAsyncExceptions# io) +interruptiblyMasked :: forall a. IO a -> IO a +interruptiblyMasked = + coerce (maskAsyncExceptions# @a) -- | Call an action with asynchronous exceptions uninterruptibly masked. -uninterruptiblyMasked :: IO a -> IO a -uninterruptiblyMasked (IO io) = - IO (maskUninterruptible# io) +uninterruptiblyMasked :: forall a. IO a -> IO a +uninterruptiblyMasked = + coerce (maskUninterruptible# @a) -- Like try, but with continuations tryEitherSTM :: (Exception e) => (e -> STM b) -> (a -> STM b) -> STM a -> STM b diff --git a/ki/src/Ki/Internal/Scope.hs b/ki/src/Ki/Internal/Scope.hs index 129f26b..bce7b7f 100644 --- a/ki/src/Ki/Internal/Scope.hs +++ b/ki/src/Ki/Internal/Scope.hs @@ -223,63 +223,74 @@ allocateScope = do -- Spawn a thread in a scope, providing it its child id and a function that sets the masking state to the requested -- masking state. The given action is called with async exceptions interruptibly masked. -spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ThreadId -spawn - Scope {childrenVar, nextChildIdSupply, statusVar} - ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState} - action = do - -- Interruptible mask is enough so long as none of the STM operations below block. - -- - -- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible - -- or MaskedUninterruptible, to avoid a branch on parentMaskingState. - interruptiblyMasked do - -- Record the thread as being about to start. Not allowed to retry. - nonblockingAtomically do - n <- nonblockingReadTVar statusVar - assert (n >= -2) do - case n of - Open -> nonblockingWriteTVar' statusVar (n + 1) - Closing -> nonblockingThrowSTM ScopeClosing - Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed") - - childId <- IntSupply.next nextChildIdSupply - - childThreadId <- - forkWithAffinity affinity do - when (not (null label)) do - childThreadId <- myThreadId - labelThread childThreadId label - - for_ allocationLimit \bytes -> do - setAllocationCounter (byteCountToInt64 bytes) - enableAllocationLimit - - let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one. - atRequestedMaskingState :: IO a -> IO a - atRequestedMaskingState = - case requestedChildMaskingState of - Unmasked -> unsafeUnmask - MaskedInterruptible -> id - MaskedUninterruptible -> uninterruptiblyMasked - - runUnexceptionalIO (action childId atRequestedMaskingState) - - nonblockingAtomically (unrecordChild childrenVar childId) - - -- Record the child as having started. Not allowed to retry. - nonblockingAtomically do - n <- nonblockingReadTVar statusVar - nonblockingWriteTVar' statusVar (n - 1) - recordChild childrenVar childId childThreadId - - pure childThreadId +spawn :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds +spawn scope@Scope {childrenVar, statusVar} options action = do + -- Interruptible mask is enough so long as none of the STM operations below block. + -- + -- Unconditionally set masking state to MaskedInterruptible, even though we might already be at MaskedInterruptible + -- or MaskedUninterruptible, to avoid a branch on parentMaskingState. + interruptiblyMasked do + -- Record the thread as being about to start. Not allowed to retry. + nonblockingAtomically do + status <- nonblockingReadTVar statusVar + assert (status >= -2) do + case status of + Open -> nonblockingWriteTVar' statusVar (status + 1) + Closing -> nonblockingThrowSTM ScopeClosing + Closed -> nonblockingThrowSTM (ErrorCall "ki: scope closed") + + childIds <- spawnChild scope options action + + -- Record the child as having started. Not allowed to retry. + nonblockingAtomically do + starting <- nonblockingReadTVar statusVar + assert (starting >= 1) do + nonblockingWriteTVar' statusVar (starting - 1) + recordChild childrenVar childIds + + pure childIds + +data ChildIds + = ChildIds + {-# UNPACK #-} !Tid + {-# UNPACK #-} !ThreadId + +spawnChild :: Scope -> ThreadOptions -> (Tid -> (forall x. IO x -> IO x) -> UnexceptionalIO ()) -> IO ChildIds +spawnChild scope options action = do + childId <- IntSupply.next nextChildIdSupply + childThreadId <- + forkWithAffinity affinity do + when (not (null label)) do + childThreadId <- myThreadId + labelThread childThreadId label + + for_ allocationLimit \bytes -> do + setAllocationCounter (byteCountToInt64 bytes) + enableAllocationLimit + + let -- Action that sets the masking state from the current (MaskedInterruptible) to the requested one. + atRequestedMaskingState :: IO a -> IO a + atRequestedMaskingState = + case requestedChildMaskingState of + Unmasked -> unsafeUnmask + MaskedInterruptible -> id + MaskedUninterruptible -> uninterruptiblyMasked + + runUnexceptionalIO (action childId atRequestedMaskingState) + + nonblockingAtomically (unrecordChild childrenVar childId) + pure (ChildIds childId childThreadId) + where + Scope {childrenVar, nextChildIdSupply} = scope + ThreadOptions {affinity, allocationLimit, label, maskingState = requestedChildMaskingState} = options +{-# INLINE spawnChild #-} -- Record our child by either: -- -- * Flipping `Nothing` to `Just childThreadId` (common case: we record child before it unrecords itself) -- * Flipping `Just _` to `Nothing` (uncommon case: we observe that a child already unrecorded itself) -recordChild :: TVar (IntMap ThreadId) -> Tid -> ThreadId -> NonblockingSTM () -recordChild childrenVar childId childThreadId = do +recordChild :: TVar (IntMap ThreadId) -> ChildIds -> NonblockingSTM () +recordChild childrenVar (ChildIds childId childThreadId) = do children <- nonblockingReadTVar childrenVar nonblockingWriteTVar' childrenVar (IntMap.Lazy.alter (maybe (Just childThreadId) (const Nothing)) childId children) @@ -298,7 +309,7 @@ awaitAll Scope {childrenVar, statusVar} = do children <- readTVar childrenVar guard (IntMap.Lazy.null children) status <- readTVar statusVar - case status of + assert (status >= -2) case status of Open -> guard (status == 0) Closing -> retry -- block until closed Closed -> pure () @@ -321,14 +332,12 @@ forkWith :: Scope -> ThreadOptions -> IO a -> IO (Thread a) forkWith scope opts action = do resultVar <- newTVarIO NoResultYet let done result = UnexceptionalIO (atomically (writeTVar resultVar result)) - ident <- + ChildIds _ childThreadId <- spawn scope opts \childId masking -> do - result <- unexceptionalTry (masking action) - case result of + unexceptionalTry (masking action) >>= \case Failure exception -> do - when - (not (isScopeClosingException exception)) - (propagateException scope childId exception) + when (not (isScopeClosingException exception)) do + propagateException scope childId exception -- even put async exceptions that we propagated. this isn't totally ideal because a caller awaiting this -- thread would not be able to distinguish between async exceptions delivered to this thread, or itself done (BadResult exception) @@ -338,7 +347,7 @@ forkWith scope opts action = do NoResultYet -> retry BadResult exception -> throwSTM exception GoodResult value -> pure value - pure (makeThread ident doAwait) + pure (makeThread childThreadId doAwait) -- | Variant of 'Ki.forkWith' for threads that never return. forkWith_ :: Scope -> ThreadOptions -> IO Void -> IO () @@ -369,7 +378,7 @@ forkTryWith :: forall e a. (Exception e) => Scope -> ThreadOptions -> IO a -> IO forkTryWith scope opts action = do resultVar <- newTVarIO NoResultYet let done result = UnexceptionalIO (atomically (writeTVar resultVar result)) - childThreadId <- + ChildIds _ childThreadId <- spawn scope opts \childId masking -> do result <- unexceptionalTry (masking action) case result of @@ -427,7 +436,7 @@ forkTryWith scope opts action = do propagateException :: Scope -> Tid -> SomeException -> UnexceptionalIO () propagateException Scope {childExceptionVar, parentThreadId, statusVar} childId exception = UnexceptionalIO (readTVarIO statusVar) >>= \case - Closing -> tryPutChildExceptionVar -- (A) / (B) + Closing -> tryPutChildExceptionVar -- (A) or (B), we don't care which status -> assert (status >= 0) loop -- we know status is Open here where loop :: UnexceptionalIO ()