Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchellwrosen committed May 16, 2024
1 parent f9e3350 commit 7eafb6a
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 128 deletions.
13 changes: 7 additions & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
- Replace `array` with `primitive`
- Make calling `cancel` more than once on a recurring timer not enter an infinite loop
- Slightly improve timer insert performance
- Fix minor race condition in `register`

## [0.4.0.1] - 2022-11-05
## [0.4.0.1] - November 5, 2022

- Fix inaccurate haddock on `recurring`

## [0.4.0] - 2022-11-05
## [0.4.0] - November 5, 2022

- Add `create`
- Rename `Data.TimerWheel` to `TimerWheel`
- Swap out `vector` for `array`
- Treat negative delays as 0
- Drop support for GHC < 8.6

## [0.3.0] - 2020-06-18
## [0.3.0] - June 18, 2020

- Add `with`
- Add support for GHC 8.8, GHC 8.10
Expand All @@ -35,11 +36,11 @@
- Remove `InvalidTimerWheelConfig` exception. `error` is used instead
- Remove support for GHC < 8.6

## [0.2.0.1] - 2019-05-19
## [0.2.0.1] - May 19, 2019

- Swap out `ghc-prim` and `primitive` for `vector`

## [0.2.0] - 2019-02-03
## [0.2.0] - February 3, 2019

- Add `destroy` function, for reaping the background thread
- Add `recurring_` function
Expand All @@ -54,6 +55,6 @@ spin forever and peg a CPU
- Rename `new` to `create`
- Make recurring timers more accurate

## [0.1.0] - 2018-07-18
## [0.1.0] - July 18, 2018

- Initial release
117 changes: 62 additions & 55 deletions src/TimerWheel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@ module TimerWheel
register_,
recurring,
recurring_,
reregister,
)
where

import Control.Exception (mask_)
import qualified Data.Atomics as Atomics
import Data.Foldable (for_)
import qualified Data.Map.Strict as Map
import Data.Primitive.Array (MutableArray)
import qualified Data.Primitive.Array as Array
import GHC.Base (RealWorld)
import qualified Ki
import TimerWheel.Internal.Counter (Counter, decrCounter_, incrCounter, incrCounter_, newCounter, readCounter)
import TimerWheel.Internal.Entries (Entries)
import qualified TimerWheel.Internal.Entries as Entries
import TimerWheel.Internal.Micros (Micros (..))
import qualified TimerWheel.Internal.Micros as Micros
import TimerWheel.Internal.Prelude
import TimerWheel.Internal.Timer (Timer (..))
import TimerWheel.Internal.TimerBucket (TimerBucket)
import qualified TimerWheel.Internal.TimerBucket as TimerBucket
import TimerWheel.Internal.Timestamp (Timestamp)
import qualified TimerWheel.Internal.Timestamp as Timestamp

Expand Down Expand Up @@ -90,7 +93,7 @@ import qualified TimerWheel.Internal.Timestamp as Timestamp
--
-- @
data TimerWheel = TimerWheel
{ buckets :: {-# UNPACK #-} !(MutableArray RealWorld Entries),
{ buckets :: {-# UNPACK #-} !(MutableArray RealWorld TimerBucket),
resolution :: {-# UNPACK #-} !Micros,
numTimers :: {-# UNPACK #-} !Counter,
-- A counter to generate unique ints that identify registered actions, so they can be canceled.
Expand All @@ -100,9 +103,6 @@ data TimerWheel = TimerWheel
totalMicros :: {-# UNPACK #-} !Micros
}

-- Internal type alias for readability
type TimerId = Int

-- | Timer wheel config.
--
-- * @spokes@ must be ∈ @[1, maxBound]@, and is set to @1024@ if invalid.
Expand All @@ -124,10 +124,10 @@ create ::
-- |
IO TimerWheel
create scope (Config spokes0 resolution0) = do
buckets <- Array.newArray spokes Entries.empty
buckets <- Array.newArray spokes Map.empty
numTimers <- newCounter
timerIdSupply <- newCounter
Ki.fork_ scope (runTimerReaperThread buckets numTimers resolution)
Ki.fork_ scope (runTimerReaperThread buckets resolution)
pure TimerWheel {buckets, numTimers, resolution, timerIdSupply, totalMicros}
where
spokes = if spokes0 <= 0 then 1024 else spokes0
Expand Down Expand Up @@ -178,14 +178,31 @@ register_ wheel delay action = do
pure ()

registerImpl :: TimerWheel -> Micros -> IO () -> IO (IO Bool)
registerImpl TimerWheel {buckets, numTimers, resolution, timerIdSupply, totalMicros} delay action = do
registerImpl wheel delay action = do
now <- Timestamp.now
incrCounter_ numTimers
timerId <- incrCounter timerIdSupply
let index = timestampToIndex buckets resolution (now `Timestamp.plus` delay)
let c = unMicros (delay `Micros.div` totalMicros)
atomicInsertIntoBucket buckets index (Entries.insert timerId c action)
pure (atomicDeleteFromBucket buckets index timerId)

let timestamp = now `Timestamp.plus` delay
let index = timestampToIndex wheel.buckets wheel.resolution timestamp
let action1 = action >> decrCounter_ wheel.numTimers

timerId <- incrCounter wheel.timerIdSupply

-- Mask so that we don't increment the timer count, then get hit by an async exception before we can put the timer
-- in the wheel
mask_ do
incrCounter_ wheel.numTimers
atomicModifyArray wheel.buckets index (TimerBucket.insert timestamp (Timer timerId action1))

pure do
let loop :: Atomics.Ticket TimerBucket -> IO Bool
loop ticket =
case TimerBucket.delete timestamp timerId (Atomics.peekTicket ticket) of
DidntDelete -> pure False
Deleted bucket -> do
(success, ticket1) <- Atomics.casArrayElem wheel.buckets index ticket bucket
if success then pure True else loop ticket1
ticket0 <- Atomics.readArrayElem wheel.buckets index
loop ticket0

-- | @recurring wheel action delay@ registers an action __@action@__ in timer wheel __@wheel@__ to fire every
-- __@delay@__ seconds (but no more often than __@wheel@__'s /resolution/).
Expand Down Expand Up @@ -273,19 +290,19 @@ recurring_ wheel (Micros.fromSeconds -> delay) action = do
-- act as if it's still "one bucket ago" at the moment we re-register
-- it.
reregister :: TimerWheel -> Micros -> IO () -> IO (IO Bool)
reregister wheel@TimerWheel {resolution} delay =
reregister wheel delay =
registerImpl
wheel
if resolution > delay
if wheel.resolution > delay
then Micros 0
else delay `Micros.minus` resolution
else delay `Micros.minus` wheel.resolution

-- | Get the number of timers in a timer wheel.
--
-- /O(1)/.
count :: TimerWheel -> IO Int
count TimerWheel {numTimers} =
readCounter numTimers
count wheel =
readCounter wheel.numTimers

-- `timestampToIndex buckets resolution timestamp` figures out which index `timestamp` corresponds to in `buckets`,
-- where each bucket corresponds to `resolution` microseconds.
Expand All @@ -305,71 +322,61 @@ count TimerWheel {numTimers} =
-- 2. Wrap around per the actual length of the array:
--
-- 105329801 `rem` 3 = 2
timestampToIndex :: MutableArray RealWorld Entries -> Micros -> Timestamp -> Int
timestampToIndex :: MutableArray RealWorld bucket -> Micros -> Timestamp -> Int
timestampToIndex buckets resolution timestamp =
-- This downcast is safe because there are at most `maxBound :: Int` buckets (not that anyone would ever have that
-- many...)
fromIntegral @Word64 @Int
(Timestamp.epoch resolution timestamp `rem` fromIntegral @Int @Word64 (Array.sizeofMutableArray buckets))

------------------------------------------------------------------------------------------------------------------------
-- Atomic operations on buckets
-- Atomic operations on arrays

atomicInsertIntoBucket :: MutableArray RealWorld Entries -> Int -> (Entries -> Entries) -> IO ()
atomicInsertIntoBucket buckets index doInsert = do
ticket0 <- Atomics.readArrayElem buckets index
atomicModifyArray :: forall a. MutableArray RealWorld a -> Int -> (a -> a) -> IO ()
atomicModifyArray array index f = do
ticket0 <- Atomics.readArrayElem array index
loop ticket0
where
loop :: Atomics.Ticket a -> IO ()
loop ticket = do
(success, ticket1) <- Atomics.casArrayElem buckets index ticket (doInsert (Atomics.peekTicket ticket))
(success, ticket1) <- Atomics.casArrayElem array index ticket (f (Atomics.peekTicket ticket))
if success then pure () else loop ticket1

atomicDeleteFromBucket :: MutableArray RealWorld Entries -> Int -> TimerId -> IO Bool
atomicDeleteFromBucket buckets index timerId = do
ticket0 <- Atomics.readArrayElem buckets index
loop ticket0
where
loop ticket =
case Entries.delete timerId (Atomics.peekTicket ticket) of
Nothing -> pure False
Just entries1 -> do
(success, ticket1) <- Atomics.casArrayElem buckets index ticket entries1
if success then pure True else loop ticket1

atomicExtractExpiredTimersFromBucket :: MutableArray RealWorld Entries -> Int -> IO [IO ()]
atomicExtractExpiredTimersFromBucket buckets index = do
atomicExtractExpiredTimersFromBucket :: MutableArray RealWorld TimerBucket -> Int -> Timestamp -> IO TimerBucket
atomicExtractExpiredTimersFromBucket buckets index now = do
ticket0 <- Atomics.readArrayElem buckets index
loop ticket0
where
loop :: Atomics.Ticket TimerBucket -> IO TimerBucket
loop ticket
| Entries.null entries = pure []
| Map.null bucket0 = pure bucket0
| otherwise = do
let (expired, entries1) = Entries.partition entries
(success, ticket1) <- Atomics.casArrayElem buckets index ticket entries1
let Pair expired bucket1 = TimerBucket.partition now bucket0
(success, ticket1) <- Atomics.casArrayElem buckets index ticket bucket1
if success then pure expired else loop ticket1
where
entries = Atomics.peekTicket ticket
bucket0 = Atomics.peekTicket ticket

------------------------------------------------------------------------------------------------------------------------
-- Timer reaper thread

runTimerReaperThread :: MutableArray RealWorld Entries -> Counter -> Micros -> IO void
runTimerReaperThread buckets numTimers resolution = do
runTimerReaperThread :: MutableArray RealWorld TimerBucket -> Micros -> IO void
runTimerReaperThread buckets resolution = do
-- Sleep until the very first bucket of timers expires
now <- Timestamp.now
let remainingBucketMicros = resolution `Micros.minus` (now `Timestamp.rem` resolution)
Micros.sleep remainingBucketMicros

loop
(now `Timestamp.plus` remainingBucketMicros)
(now `Timestamp.plus` remainingBucketMicros `Timestamp.plus` resolution)
(timestampToIndex buckets resolution now)
where
loop :: Timestamp -> Int -> IO void
loop !nextTime !index = do
expired <- atomicExtractExpiredTimersFromBucket buckets index
for_ expired \action -> do
action
decrCounter_ numTimers
-- FIXME read over this carefully and document
loop :: Timestamp -> Timestamp -> Int -> IO void
loop !thisTime !nextTime !index = do
expired <- atomicExtractExpiredTimersFromBucket buckets index thisTime
TimerBucket.fire expired
now <- Timestamp.now
when (now < nextTime) (Micros.sleep (nextTime `Timestamp.minus` now))
loop (nextTime `Timestamp.plus` resolution) ((index + 1) `rem` Array.sizeofMutableArray buckets)
Micros.sleep (nextTime `Timestamp.minus` now) -- it's ok if now > nextTime; that'll return immediately
loop nextTime (nextTime `Timestamp.plus` resolution) ((index + 1) `rem` Array.sizeofMutableArray buckets)
63 changes: 0 additions & 63 deletions src/TimerWheel/Internal/Entries.hs

This file was deleted.

15 changes: 14 additions & 1 deletion src/TimerWheel/Internal/Prelude.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module TimerWheel.Internal.Prelude
( Seconds,
( DeleteResult (..),
Pair (..),
Seconds,
module X,
)
where
Expand All @@ -8,9 +10,20 @@ import Control.Monad as X (when)
import Data.Coerce as X (coerce)
import Data.Fixed (E6, Fixed)
import Data.IORef as X (newIORef, readIORef, writeIORef)
import Data.Map as X (Map)
import Data.Word as X (Word64)
import GHC.Generics as X (Generic)

-- | The result of attempting to delete something.
data DeleteResult a
= Deleted !a
| DidntDelete
deriving stock (Functor)

-- | A strict pair.
data Pair a b
= Pair !a !b

-- | A number of seconds, with microsecond precision.
--
-- You can use numeric literals to construct a value of this type, e.g. @0.5@.
Expand Down
14 changes: 14 additions & 0 deletions src/TimerWheel/Internal/Timer.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module TimerWheel.Internal.Timer
( TimerId,
Timer (..),
)
where

data Timer
= Timer
{ id :: !TimerId,
action :: !(IO ())
}

type TimerId =
Int
Loading

0 comments on commit 7eafb6a

Please sign in to comment.