Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility for General Purpose Parallelism #4009

Merged
merged 7 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package io.joern.x2cpg.utils

import java.util
import java.util.concurrent.{Callable, Executors}
import java.util.stream.{Collectors, StreamSupport}
import java.util.{Collections, Spliterator, Spliterators}
import scala.jdk.CollectionConverters.*
import scala.util.Try

/** A utility for providing out-of-the-box basic concurrent execution for a collection of Scala functions.
*/
object ConcurrentTaskUtil {

private val MAX_POOL_SIZE = Runtime.getRuntime.availableProcessors()

/** Uses a thread pool with a limited number of active threads executing a task at any given point. This is effective
* when tasks may require large amounts of memory, or single tasks are too short lived.
*
* @param tasks
* the tasks to parallelize.
* @param maxPoolSize
* the max pool size to allow for active threads.
* @tparam V
* the output type of each task.
* @return
* an array of the executed tasks as either a success or failure.
*/
def runUsingThreadPool[V](tasks: Iterator[() => V], maxPoolSize: Int = MAX_POOL_SIZE): List[Try[V]] = {
val ex = Executors.newFixedThreadPool(maxPoolSize)
try {
val callables = Collections.list(tasks.map { x =>
new Callable[V] {
override def call(): V = x.apply()
}
}.asJavaEnumeration)
ex.invokeAll(callables).asScala.map(x => Try(x.get())).toList
} finally {
ex.shutdown()
}
}

/** Uses a Spliterator to run a number of tasks in parallel, where any number of threads may be alive at any point.
* This is useful for running a large number of tasks with low memory consumption. Spliterator's default thread pool
* is ForkJoinPool.commonPool().
*
* @param tasks
* the tasks to parallelize.
* @tparam V
* the output type of each task.
* @return
* an array of the executed tasks as either a success or failure.
*/
def runUsingSpliterator[V](tasks: Iterator[() => V]): List[Try[V]] = {
StreamSupport
.stream(Spliterators.spliteratorUnknownSize(tasks.asJava, Spliterator.NONNULL), /* parallel */ true)
.map(task => Try(task.apply()))
.collect(Collectors.toList())
.asScala
.toList
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.joern.x2cpg.utils

import org.scalatest.Assertions
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Success

class ConcurrentTaskUtilTests extends AnyWordSpec with Matchers {

private def assumeMultipleProcessors =
assume(Runtime.getRuntime.availableProcessors() > 1, "!!! Number of available processors not larger than 1 !!!")

"compared to serial execution, concurrent execution" should {

"perform better against a large number of 'expensive' operations using a spliterator" in {
assumeMultipleProcessors
def problem = Iterator.fill(500)(() => Thread.sleep(10))

val parStart = System.nanoTime()
ConcurrentTaskUtil.runUsingSpliterator(problem)
val parTotal = System.nanoTime() - parStart

val serStart = System.nanoTime()
problem.foreach(x => x())
val serTotal = System.nanoTime() - serStart

parTotal should be < serTotal
}

"perform better against a large number of 'cheap' operations using a thread pool" in {
assumeMultipleProcessors
def problem = Iterator.fill(500)(() => Thread.sleep(1))

val parStart = System.nanoTime()
ConcurrentTaskUtil.runUsingThreadPool(problem)
val parTotal = System.nanoTime() - parStart

val serStart = System.nanoTime()
problem.foreach(x => x())
val serTotal = System.nanoTime() - serStart

parTotal should be < serTotal
}
}

"provide the means to let the caller handle unsuccessful operations without propagating an exception" in {
val problem = Iterator(() => "Success!", () => "Success!", () => throw new RuntimeException("Failure!"))
val result = ConcurrentTaskUtil.runUsingThreadPool(problem)
result.count {
case Success(_) => true
case _ => false
} shouldBe 2
}

}
Loading