Skip to content

Commit

Permalink
Merge pull request #3332 from armanbilge/feature/jvm-polling-system
Browse files Browse the repository at this point in the history
Polling system
  • Loading branch information
armanbilge authored Jun 13, 2023
2 parents 9780973 + 58f695f commit e9aeb8c
Show file tree
Hide file tree
Showing 41 changed files with 1,930 additions and 246 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ jobs:
java: graalvm@11
- os: windows-latest
scala: 3.3.0
ci: ciJVM
- os: macos-latest
scala: 3.3.0
ci: ciJVM
- os: windows-latest
scala: 2.12.18
ci: ciJVM
- os: macos-latest
scala: 2.12.18
ci: ciJVM
- ci: ciFirefox
scala: 3.3.0
- ci: ciChrome
Expand Down Expand Up @@ -97,9 +101,6 @@ jobs:
- os: macos-latest
ci: ciNative
scala: 2.12.18
- os: macos-latest
ci: ciNative
scala: 3.3.0
- os: windows-latest
java: graalvm@11
runs-on: ${{ matrix.os }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,13 @@ class WorkStealingBenchmark {
(ExecutionContext.fromExecutor(executor), () => executor.shutdown())
}

val compute = new WorkStealingThreadPool(
val compute = new WorkStealingThreadPool[AnyRef](
256,
"io-compute",
"io-blocker",
60.seconds,
false,
SleepSystem,
_.printStackTrace())

val cancelationCheckThreshold =
Expand Down
23 changes: 16 additions & 7 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ThisBuild / git.gitUncommittedChanges := {
}
}

ThisBuild / tlBaseVersion := "3.5"
ThisBuild / tlBaseVersion := "3.6"
ThisBuild / tlUntaggedAreSnapshots := false

ThisBuild / organization := "org.typelevel"
Expand Down Expand Up @@ -224,8 +224,8 @@ ThisBuild / githubWorkflowBuildMatrixExclusions := {
val windowsAndMacScalaFilters =
(ThisBuild / githubWorkflowScalaVersions).value.filterNot(Set(Scala213)).flatMap { scala =>
Seq(
MatrixExclude(Map("os" -> Windows, "scala" -> scala)),
MatrixExclude(Map("os" -> MacOS, "scala" -> scala)))
MatrixExclude(Map("os" -> Windows, "scala" -> scala, "ci" -> CI.JVM.command)),
MatrixExclude(Map("os" -> MacOS, "scala" -> scala, "ci" -> CI.JVM.command)))
}

val jsScalaFilters = for {
Expand Down Expand Up @@ -254,9 +254,7 @@ ThisBuild / githubWorkflowBuildMatrixExclusions := {

javaFilters ++ Seq(
MatrixExclude(Map("os" -> Windows, "ci" -> ci)),
MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala212)),
// keep a native+2.13+macos job
MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala3))
MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala212))
)
}

Expand Down Expand Up @@ -640,7 +638,10 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.IOFiberConstants.ExecuteRunnableR"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.scope"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"cats.effect.IOFiberConstants.ContStateResult")
"cats.effect.IOFiberConstants.ContStateResult"),
// introduced by #3332, polling system
ProblemFilters.exclude[DirectMissingMethodProblem](
"cats.effect.unsafe.IORuntimeBuilder.this")
) ++ {
if (tlIsScala3.value) {
// Scala 3 specific exclusions
Expand Down Expand Up @@ -824,6 +825,14 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform)
} else Seq()
}
)
.nativeSettings(
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.unsafe.PollingExecutorScheduler$SleepTask"),
ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.QueueExecutorScheduler"),
ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.QueueExecutorScheduler$")
)
)
.disablePlugins(JCStressPlugin)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.concurrent.duration.FiniteDuration
// Can you imagine a thread pool on JS? Have fun trying to extend or instantiate
// this class. Unfortunately, due to the explicit branching, this type leaks
// into the shared source code of IOFiber.scala.
private[effect] sealed abstract class WorkStealingThreadPool private ()
private[effect] sealed abstract class WorkStealingThreadPool[P] private ()
extends ExecutionContext {
def execute(runnable: Runnable): Unit
def reportFailure(cause: Throwable): Unit
Expand All @@ -38,12 +38,12 @@ private[effect] sealed abstract class WorkStealingThreadPool private ()
private[effect] def canExecuteBlockingCode(): Boolean
private[unsafe] def liveTraces(): (
Map[Runnable, Trace],
Map[WorkerThread, (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])],
Map[WorkerThread[P], (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])],
Map[Runnable, Trace])
}

private[unsafe] sealed abstract class WorkerThread private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean
private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread {
private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean
private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle
private[unsafe] def index: Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
private[effect] sealed class FiberMonitor(
// A reference to the compute pool of the `IORuntime` in which this suspended fiber bag
// operates. `null` if the compute pool of the `IORuntime` is not a `WorkStealingThreadPool`.
private[this] val compute: WorkStealingThreadPool
private[this] val compute: WorkStealingThreadPool[_]
) extends FiberMonitorShared {

private[this] final val BagReferences =
Expand All @@ -69,8 +69,8 @@ private[effect] sealed class FiberMonitor(
*/
def monitorSuspended(fiber: IOFiber[_]): WeakBag.Handle = {
val thread = Thread.currentThread()
if (thread.isInstanceOf[WorkerThread]) {
val worker = thread.asInstanceOf[WorkerThread]
if (thread.isInstanceOf[WorkerThread[_]]) {
val worker = thread.asInstanceOf[WorkerThread[_]]
// Guard against tracking errors when multiple work stealing thread pools exist.
if (worker.isOwnedBy(compute)) {
worker.monitor(fiber)
Expand Down Expand Up @@ -116,14 +116,14 @@ private[effect] sealed class FiberMonitor(
val externalFibers = external.collect(justFibers)
val suspendedFibers = suspended.collect(justFibers)
val workersMapping: Map[
WorkerThread,
WorkerThread[_],
(Thread.State, Option[(IOFiber[_], Trace)], Map[IOFiber[_], Trace])] =
workers.map {
case (thread, (state, opt, set)) =>
val filteredOpt = opt.collect(justFibers)
val filteredSet = set.collect(justFibers)
(thread, (state, filteredOpt, filteredSet))
}
}.toMap

(externalFibers, workersMapping, suspendedFibers)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2020-2023 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package unsafe

abstract class PollingSystem {

/**
* The user-facing interface.
*/
type Api <: AnyRef

/**
* The thread-local data structure used for polling.
*/
type Poller <: AnyRef

def close(): Unit

def makeApi(register: (Poller => Unit) => Unit): Api

def makePoller(): Poller

def closePoller(poller: Poller): Unit

/**
* @param nanos
* the maximum duration for which to block, where `nanos == -1` indicates to block
* indefinitely.
*
* @return
* whether any events were polled
*/
def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean

/**
* @return
* whether poll should be called again (i.e., there are more events to be polled)
*/
def needsPoll(poller: Poller): Boolean

def interrupt(targetThread: Thread, targetPoller: Poller): Unit

}

private object PollingSystem {
type WithPoller[P] = PollingSystem {
type Poller = P
}
}
9 changes: 7 additions & 2 deletions core/jvm/src/main/scala/cats/effect/IOApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ trait IOApp {
*/
protected def runtimeConfig: unsafe.IORuntimeConfig = unsafe.IORuntimeConfig()

protected def pollingSystem: unsafe.PollingSystem =
unsafe.IORuntime.createDefaultPollingSystem()

/**
* Controls the number of worker threads which will be allocated to the compute pool in the
* underlying runtime. In general, this should be no ''greater'' than the number of physical
Expand Down Expand Up @@ -338,11 +341,12 @@ trait IOApp {
import unsafe.IORuntime

val installed = IORuntime installGlobal {
val (compute, compDown) =
val (compute, poller, compDown) =
IORuntime.createWorkStealingComputeThreadPool(
threads = computeWorkerThreadCount,
reportFailure = t => reportFailure(t).unsafeRunAndForgetWithoutCallback()(runtime),
blockedThreadDetectionEnabled = blockedThreadDetectionEnabled
blockedThreadDetectionEnabled = blockedThreadDetectionEnabled,
pollingSystem = pollingSystem
)

val (blocking, blockDown) =
Expand All @@ -352,6 +356,7 @@ trait IOApp {
compute,
blocking,
compute,
List(poller),
{ () =>
compDown()
blockDown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,5 @@ private[effect] abstract class IOCompanionPlatform { this: IO.type =>
*/
def readLine: IO[String] =
Console[IO].readLine

}
36 changes: 36 additions & 0 deletions core/jvm/src/main/scala/cats/effect/Selector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2020-2023 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

import java.nio.channels.SelectableChannel
import java.nio.channels.spi.SelectorProvider

trait Selector {

/**
* The [[java.nio.channels.spi.SelectorProvider]] that should be used to create
* [[java.nio.channels.SelectableChannel]]s that are compatible with this polling system.
*/
def provider: SelectorProvider

/**
* Fiber-block until a [[java.nio.channels.SelectableChannel]] is ready on at least one of the
* designated operations. The returned value will indicate which operations are ready.
*/
def select(ch: SelectableChannel, ops: Int): IO[Int]

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import scala.concurrent.ExecutionContext

private[unsafe] trait FiberMonitorCompanionPlatform {
def apply(compute: ExecutionContext): FiberMonitor = {
if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool]) {
val wstp = compute.asInstanceOf[WorkStealingThreadPool]
if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool[_]]) {
val wstp = compute.asInstanceOf[WorkStealingThreadPool[_]]
new FiberMonitor(wstp)
} else {
new FiberMonitor(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,36 @@ package cats.effect.unsafe

private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder =>

protected var customPollingSystem: Option[PollingSystem] = None

/**
* Override the default [[PollingSystem]]
*/
def setPollingSystem(system: PollingSystem): IORuntimeBuilder = {
if (customPollingSystem.isDefined) {
throw new RuntimeException("Polling system can only be set once")
}
customPollingSystem = Some(system)
this
}

// TODO unify this with the defaults in IORuntime.global and IOApp
protected def platformSpecificBuild: IORuntime = {
val (compute, computeShutdown) =
customCompute.getOrElse(
IORuntime.createWorkStealingComputeThreadPool(reportFailure = failureReporter))
val (compute, poller, computeShutdown) =
customCompute
.map {
case (c, s) =>
(c, Nil, s)
}
.getOrElse {
val (c, p, s) =
IORuntime.createWorkStealingComputeThreadPool(
pollingSystem =
customPollingSystem.getOrElse(IORuntime.createDefaultPollingSystem()),
reportFailure = failureReporter
)
(c, List(p), s)
}
val xformedCompute = computeTransform(compute)

val (scheduler, schedulerShutdown) = xformedCompute match {
Expand All @@ -36,6 +61,7 @@ private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder
computeShutdown()
blockingShutdown()
schedulerShutdown()
extraPollers.foreach(_._2())
extraShutdownHooks.reverse.foreach(_())
}
val runtimeConfig = customConfig.getOrElse(IORuntimeConfig())
Expand All @@ -44,6 +70,7 @@ private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder
computeTransform(compute),
blockingTransform(blocking),
scheduler,
poller ::: extraPollers.map(_._1),
shutdown,
runtimeConfig
)
Expand Down
Loading

0 comments on commit e9aeb8c

Please sign in to comment.