Skip to content

Commit

Permalink
[joern-slice] Fixed Parallelism (#4010)
Browse files Browse the repository at this point in the history
* Utility for General Purpose Parallelism
Following the type recovery generator parallelism fix, and noting that faulty parallelism exists elsewhere (for example in the slicing), I've created this general-purpose concurrency tool as an easy way to bootstrap effective concurrency for tasks that are independent of execution order.

* Allowed execution context to be defined as an implicit arg

* [joern-slice] Fixed Parallelism
Following #4009, I have replaced the parallelism in the slicing with this utility.

* Brought in the new general-purpose concurrency util

* Increased the number of tasks so that the parallel run reliably outperforms the serial one
  • Loading branch information
DavidBakerEffendi authored Jan 10, 2024
1 parent 126f942 commit abbb77f
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,27 +1,40 @@
package io.joern.dataflowengineoss.slicing

import io.joern.dataflowengineoss.language.*
import io.joern.x2cpg.utils.ConcurrentTaskUtil
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.PropertyNames
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.*
import org.slf4j.LoggerFactory

import java.util.concurrent.{Callable, Executors}
import java.util.concurrent.Callable
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}

object DataFlowSlicing {

implicit val resolver: ICallResolver = NoResolve
private val logger = LoggerFactory.getLogger(getClass)

/** Computes a single data-flow slice for the calls selected by `config`.
  *
  * Candidate calls come from the file named by `config.fileFilter` when it is set, otherwise from every call in the
  * CPG. They are then passed through `withMethodNameFilter`, `withMethodParameterFilter`,
  * `withMethodAnnotationFilter` and `withExternalCalleeFilter`. One `TrackDataFlowTask` is created per remaining call
  * and the per-task slices are merged into one.
  *
  * NOTE(review): this listing is a rendered diff without +/- markers, so superseded lines and their replacements
  * appear side by side; see the inline notes below.
  *
  * @param cpg
  *   the code property graph whose calls are sliced
  * @param config
  *   slicing options; `fileFilter` and `parallelism` are read here, and the config is also handed to every task
  * @return
  *   a slice whose nodes and edges are the `++` of all per-task slices, or `None` when no task produced a slice
  */
def calculateDataFlowSlice(cpg: Cpg, config: DataFlowConfig): Option[DataFlowSlice] = {
// Expose the config implicitly within this method.
implicit val implicitConfig: DataFlowConfig = config

// NOTE(review): the next two lines look like the pre-change side of the diff (executor obtained from
// `poolFromConfig`); the `val tasks = ...` line after them appears to be their replacement - confirm.
val exec = poolFromConfig(config)
(config.fileFilter match {
val tasks = (config.fileFilter match {
case Some(fileName) => cpg.file.nameExact(fileName).method.call
case None => cpg.call
}).method.withMethodNameFilter.withMethodParameterFilter.withMethodAnnotationFilter.call.withExternalCalleeFilter
// NOTE(review): presumably the superseded variant, which submits each task to `exec` and blocks on `get()`; the
// two lines after it wrap each task in a thunk and expose them as an iterator instead - confirm.
.map(c => exec.submit(new TrackDataFlowTask(config, c)))
.flatMap(_.get())
.map(c => () => new TrackDataFlowTask(config, c).call())
.iterator

// Hand the thunks to the shared utility; the second argument is `config.parallelism`, falling back to the number
// of available processors when it is unset.
ConcurrentTaskUtil
.runUsingThreadPool(tasks, config.parallelism.getOrElse(Runtime.getRuntime.availableProcessors()))
.flatMap {
case Success(slice) => slice
// A failed task is logged and contributes nothing to the result.
case Failure(e) =>
logger.warn("Exception encountered during slicing task", e)
None
}
// Merge all slices pairwise; `reduceOption` yields `None` when there is nothing to merge.
.reduceOption { (a, b) => DataFlowSlice(a.nodes ++ b.nodes, a.edges ++ b.edges) }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
package io.joern.dataflowengineoss.slicing

import io.joern.x2cpg.utils.ConcurrentTaskUtil
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{Operators, PropertyNames}
import io.shiftleft.semanticcpg.language.*
import org.slf4j.LoggerFactory

import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.regex.Pattern
import scala.collection.concurrent.TrieMap
import scala.util.Try
import scala.util.{Failure, Success, Try}

/** A utility for slicing based on usage references for identifiers and parameters. This is mainly tested against
* JavaScript CPGs.
*/
object UsageSlicing {

private val logger = LoggerFactory.getLogger(getClass)
private val resolver = NoResolve
private val constructorTypeMatcher = Pattern.compile(".*new (\\w+)\\(.*")
private val excludeOperatorCalls = new AtomicBoolean(false)
Expand All @@ -38,30 +41,30 @@ object UsageSlicing {

def typeMap = TrieMap.from(cpg.typeDecl.map(f => (f.name, f.fullName)).toMap)

val exec = poolFromConfig(config)
try {
val slices = usageSlices(exec, cpg, declarations, typeMap)
val userDefTypes = userDefinedTypes(cpg)
ProgramUsageSlice(slices, userDefTypes)
} finally {
exec.shutdown()
}
val slices = usageSlices(cpg, declarations, typeMap)
val userDefTypes = userDefinedTypes(cpg)
ProgramUsageSlice(slices, userDefTypes)
}

import io.shiftleft.semanticcpg.codedumper.CodeDumper.dump

private def usageSlices(
exec: ExecutorService,
cpg: Cpg,
declarations: List[Declaration],
typeMap: TrieMap[String, String]
)(implicit config: UsagesConfig): List[MethodUsageSlice] = {
private def usageSlices(cpg: Cpg, declarations: List[Declaration], typeMap: TrieMap[String, String])(implicit
config: UsagesConfig
): List[MethodUsageSlice] = {
val language = cpg.metaData.language.headOption
val root = cpg.metaData.root.headOption
declarations
val tasks = declarations
.filter(a => atLeastNCalls(a, config.minNumCalls) && !a.name.startsWith("_tmp_"))
.map(a => exec.submit(new TrackUsageTask(cpg, a, typeMap)))
.flatMap(_.get)
.map(a => () => new TrackUsageTask(cpg, a, typeMap).call())
.iterator
ConcurrentTaskUtil
.runUsingThreadPool(tasks, config.parallelism.getOrElse(Runtime.getRuntime.availableProcessors()))
.flatMap {
case Success(slice) => slice
case Failure(e) =>
logger.warn("Exception encountered during slicing task", e)
None
}
.groupBy { case (scope, _) => scope }
.view
.sortBy(_._1.fullName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,6 @@ package object slicing {
excludeMethodSource: Boolean = false
) extends BaseConfig[UsagesConfig]

/** Chooses an executor service from `config.parallelism`: a single-thread executor when it is exactly 1, a
  * work-stealing pool created with the requested parallelism when it is greater than 1, and a default work-stealing
  * pool in every other case (unset, zero or negative).
  *
  * NOTE(review): the hunk header above (11 lines before, 6 after) suggests this helper is the code removed by the
  * commit - confirm before relying on it.
  */
def poolFromConfig(config: BaseConfig[_]): ExecutorService = config.parallelism match
case Some(parallelism) if parallelism == 1 => Executors.newSingleThreadExecutor()
case Some(parallelism) if parallelism > 1 => Executors.newWorkStealingPool(parallelism)
// Catch-all: no parallelism configured, or a non-positive value.
case _ => Executors.newWorkStealingPool()

/** Adds extensions to modify a call traversal based on config options.
*/
implicit class CallFilterExt(trav: Iterator[Call]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ConcurrentTaskUtilTests extends AnyWordSpec with Matchers {

"perform better against a large number of 'cheap' operations using a thread pool" in {
assumeMultipleProcessors
def problem = Iterator.fill(500)(() => Thread.sleep(1))
def problem = Iterator.fill(1000)(() => Thread.sleep(1))

val parStart = System.nanoTime()
ConcurrentTaskUtil.runUsingThreadPool(problem)
Expand Down

0 comments on commit abbb77f

Please sign in to comment.