Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
scottsand-db committed Sep 20, 2024
1 parent 0cbf406 commit e44a96f
Showing 1 changed file with 76 additions and 56 deletions.
132 changes: 76 additions & 56 deletions project/TestParallelization.scala
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import TestParallelization.SimpleHashStrategy.HIGH_DURATION_TEST_SUITES
import sbt.Keys._
import sbt._

// scalastyle:off println

object TestParallelization {

lazy val numShardsOpt = sys.env.get("NUM_SHARDS").map(_.toInt)
lazy val testParallelismOpt = sys.env.get("TEST_PARALLELISM_COUNT").map(_.toInt)

lazy val settings = {
val parallelismCount = sys.env.get("TEST_PARALLELISM_COUNT")
if (parallelismCount.exists(_.toInt > 1)) {
if (numShardsOpt.exists(_ > 1) && testParallelismOpt.exists(_ > 1)) {
customTestGroupingSettings ++ simpleGroupingStrategySettings
} else {
Seq.empty[Setting[_]]
}
}

/**
* Replace the default value for Test / testGrouping settingKey
* and set it to a new value calculated by using the custom Task
* [[testGroupingStrategy]]. Adding these settings to the build
* will require to separately provide a value for the TaskKey
* [[testGroupingStrategy]]
* Replace the default value for Test / testGrouping settingKey and set it to a new value
* calculated by using the custom Task [[testGroupingStrategy]].
*
* Adding these settings to the build will require us to separately provide a value for the
* TaskKey [[testGroupingStrategy]]
*/
lazy val customTestGroupingSettings = {
Seq(
Expand All @@ -37,22 +38,16 @@ object TestParallelization {
group =>
logger.info(s"${group.name} contains ${group.tests.size} tests")
}

println(groupingStrategy)
sys.exit()
groups
}
)
}

/**
* Sets the Test / testGroupingStrategy Task to an instance of the
* SimpleHashStrategy
*/
/** Sets the Test / testGroupingStrategy Task to an instance of the SimpleHashStrategy */
lazy val simpleGroupingStrategySettings = Seq(
Test / forkTestJVMCount := {
sys.env.get("TEST_PARALLELISM_COUNT").map(_.toInt)
.getOrElse(java.lang.Runtime.getRuntime.availableProcessors)
testParallelismOpt.getOrElse(java.lang.Runtime.getRuntime.availableProcessors)
},
Test / shardId := {
sys.env.get("SHARD_ID").map(_.toInt)
Expand All @@ -61,7 +56,7 @@ object TestParallelization {
val groupsCount = (Test / forkTestJVMCount).value
val shard = (Test / shardId).value
val baseJvmDir = baseDirectory.value
SimpleHashStrategy(groupsCount, baseJvmDir, shard, defaultForkOptions.value)
GreedyHashStrategy(groupsCount, baseJvmDir, shard, defaultForkOptions.value)
},
Test / parallelExecution := true,
Global / concurrentRestrictions := {
Expand Down Expand Up @@ -97,50 +92,50 @@ object TestParallelization {
/**
* Base trait to group tests.
*
* By default, SBT will run all tests as if they belong to a single group,
* but allows tests to be grouped. Setting [[sbt.Keys.testGrouping]] to
* a list of groups replaces the default single-group definition.
* By default, SBT will run all tests as if they belong to a single group, but allows tests to be
* grouped. Setting [[sbt.Keys.testGrouping]] to a list of groups replaces the default
* single-group definition.
*
* When creating an instance of [[sbt.Tests.Group]] it is possible to specify
* an [[sbt.Tests.TestRunPolicy]]: this parameter can be used to use multiple
* subprocesses for test execution
* When creating an instance of [[sbt.Tests.Group]] it is possible to specify an
* [[sbt.Tests.TestRunPolicy]]: this parameter can be used to run the tests in multiple
* subprocesses
*/
sealed trait GroupingStrategy {

/**
* Adds an [[sbt.TestDefinition]] to this GroupingStrategy and
* returns an updated Grouping Strategy
* Adds an [[sbt.TestDefinition]] to this GroupingStrategy and returns an updated Grouping
* Strategy
*/
def add(testDefinition: TestDefinition): GroupingStrategy

/**
* Returns the test groups built from this GroupingStrategy
*/
/** Returns the test groups built from this GroupingStrategy */
def testGroups: List[Tests.Group]
}

class SimpleHashStrategy private(
groups: Map[Int, Tests.Group],
shardIdOpt: Option[Int],
class GreedyHashStrategy private(
groups: Map[Int, Tests.Group], // TEST_PARALLELISM_COUNT
shardIdOpt: Option[Int], // which shard ID is this out of NUM_SHARDS
var groupDurations: Array[Double]
) extends GroupingStrategy {

import TestParallelization.GreedyHashStrategy.HIGH_DURATION_TEST_SUITES

lazy val testGroups = groups.values.toList
val groupCount = groups.size

override def add(testDefinition: TestDefinition): GroupingStrategy = {
def standardGroupAssignment(): SimpleHashStrategy = {
def assignTestToGroupByHash(): GreedyHashStrategy = {
val groupIdx = math.abs(testDefinition.name.hashCode % groupCount)
val currentGroup = groups(groupIdx)
val updatedGroup = currentGroup.withTests(
currentGroup.tests :+ testDefinition
)
new SimpleHashStrategy(groups + (groupIdx -> updatedGroup), shardIdOpt, groupDurations)
new GreedyHashStrategy(groups + (groupIdx -> updatedGroup), shardIdOpt, groupDurations)
}

// Case 0: If we are not using sharding, then just assign to a group using the name hash
if (shardIdOpt.isEmpty || numShardsOpt.isEmpty) {
return standardGroupAssignment()
return assignTestToGroupByHash()
}

val shardId = shardIdOpt.get
Expand All @@ -153,29 +148,29 @@ object TestParallelization {

// We are using sharding. Now we just need to determine if this test suite belongs to this
// shard. There are two cases:
// 1) it is a high-duration test. It belongs to this shard if the test suite's index in
// HIGH_DURATION_TEST_SUITES % numShards equals shardId
// 2) it is not a high-duration test. It belongs to this shard if the hash of the test suite
// name % numShards equals shardId
// Case 1: it is a high-duration test. It belongs to this shard if the test suite's index in
// HIGH_DURATION_TEST_SUITES % numShards equals shardId
// Case 2: it is not a high-duration test. It belongs to this shard if the hash of the test
// suite name % numShards equals shardId

val testSuiteName = testDefinition.name

println(s"Trying to assign test suite: $testSuiteName. This is shardId: $shardId.")
// println(s"Trying to assign test suite: $testSuiteName. This is shardId: $shardId.")

val highDurationTestIndex = HIGH_DURATION_TEST_SUITES.indexWhere(_._1 == testSuiteName)

println(s"highDurationTestIndex: $highDurationTestIndex")
// println(s"highDurationTestIndex: $highDurationTestIndex")

if (highDurationTestIndex >= 0 && highDurationTestIndex % numShardsOpt.get == shardId) {
println(s"[High Duration] Assigning test suite $testSuiteName to the current shard $shardId")
// println(s"[High Duration] Assigning test suite $testSuiteName to the current shard $shardId")

val estimatedDuration = HIGH_DURATION_TEST_SUITES(highDurationTestIndex)._2

println(s"Test suite $testSuiteName has estimated duration $estimatedDuration")
// println(s"Test suite $testSuiteName has estimated duration $estimatedDuration")

val groupIdxWithLowestDuration = groupDurations.zipWithIndex.minBy(_._1)._2

println(s"Assigning test suite $testSuiteName to group $groupIdxWithLowestDuration")
// println(s"Assigning test suite $testSuiteName to group $groupIdxWithLowestDuration")

groupDurations(groupIdxWithLowestDuration) += estimatedDuration

Expand All @@ -184,16 +179,16 @@ object TestParallelization {
currentGroup.tests :+ testDefinition
)

new SimpleHashStrategy(groups + (groupIdxWithLowestDuration -> updatedGroup), shardIdOpt, groupDurations)
new GreedyHashStrategy(groups + (groupIdxWithLowestDuration -> updatedGroup), shardIdOpt, groupDurations)
} else if (math.abs(testDefinition.name.hashCode % numShardsOpt.get) == shardId) {
println(s"[Normal] Assigning test suite $testSuiteName to the current shard $shardId")
// println(s"[Normal] Assigning test suite $testSuiteName to the current shard $shardId")

standardGroupAssignment()
assignTestToGroupByHash()
} else {
println(s"NOT assigning test suite $testSuiteName to the current shard $shardId")
// println(s"NOT assigning test suite $testSuiteName to the current shard $shardId")

// If not assigned to this shard, just return the unchanged strategy
new SimpleHashStrategy(groups, shardIdOpt, groupDurations)
new GreedyHashStrategy(groups, shardIdOpt, groupDurations)
}
}

Expand All @@ -207,11 +202,11 @@ object TestParallelization {
f" Group $groupIndex: $groupDuration%.2f minutes"
}.mkString("\n")

s"$shardInfo" + s"Group Durations:\n$groupInfo"
s"$shardInfo" + s"Estimated Group Durations:\n$groupInfo"
}
}

object SimpleHashStrategy {
object GreedyHashStrategy {

val HIGH_DURATION_TEST_SUITES: List[(String, Double)] = List(
("org.apache.spark.sql.delta.MergeIntoDVsWithPredicatePushdownCDCSuite", 36.09),
Expand All @@ -234,7 +229,36 @@ object TestParallelization {
("org.apache.spark.sql.delta.DeltaSourceLargeLogSuite", 5.61),
("org.apache.spark.sql.delta.stats.DataSkippingDeltaV1NameColumnMappingSuite", 5.43),
("org.apache.spark.sql.delta.GenerateIdentityValuesSuite", 5.4),
("org.apache.spark.sql.delta.commands.backfill.RowTrackingBackfillConflictsSuite", 5.02)
("org.apache.spark.sql.delta.commands.backfill.RowTrackingBackfillConflictsSuite", 5.02),
("org.apache.spark.sql.delta.ImplicitStreamingMergeCastingSuite", 4.77),
("org.apache.spark.sql.delta.DeltaVacuumWithCoordinatedCommitsBatch100Suite", 4.73),
("org.apache.spark.sql.delta.CoordinatedCommitsBatchBackfill1DeltaLogSuite", 4.64),
("org.apache.spark.sql.delta.DeltaLogSuite", 4.6),
("org.apache.spark.sql.delta.IdentityColumnIngestionScalaSuite", 4.36),
("org.apache.spark.sql.delta.DeltaVacuumSuite", 4.22),
("org.apache.spark.sql.delta.columnmapping.RemoveColumnMappingCDCSuite", 4.12),
("org.apache.spark.sql.delta.DeltaSuite", 4.05),
("org.apache.spark.sql.delta.UpdateSQLSuite", 3.99),
("org.apache.spark.sql.delta.typewidening.TypeWideningInsertSchemaEvolutionSuite", 3.92),
("org.apache.spark.sql.delta.cdc.DeleteCDCSuite", 3.9),
("org.apache.spark.sql.delta.CoordinatedCommitsBatchBackfill100DeltaLogSuite", 3.86),
("org.apache.spark.sql.delta.rowid.UpdateWithRowTrackingCDCSuite", 3.83),
("org.apache.spark.sql.delta.expressions.HilbertIndexSuite", 3.75),
("org.apache.spark.sql.delta.DeltaProtocolVersionSuite", 3.71),
("org.apache.spark.sql.delta.CoordinatedCommitsBatchBackfill2DeltaLogSuite", 3.68),
("org.apache.spark.sql.delta.CheckpointsWithCoordinatedCommitsBatch100Suite", 3.59),
("org.apache.spark.sql.delta.ConvertToDeltaScalaSuite", 3.59),
("org.apache.spark.sql.delta.typewidening.TypeWideningTableFeatureSuite", 3.49),
("org.apache.spark.sql.delta.cdc.UpdateCDCSuite", 3.42),
("org.apache.spark.sql.delta.CloneTableScalaDeletionVectorSuite", 3.41),
("org.apache.spark.sql.delta.IdentityColumnSyncScalaSuite", 3.33),
("org.apache.spark.sql.delta.DeleteSQLSuite", 3.31),
("org.apache.spark.sql.delta.CheckpointsWithCoordinatedCommitsBatch2Suite", 3.19),
("org.apache.spark.sql.delta.DeltaSourceIdColumnMappingSuite", 3.18),
("org.apache.spark.sql.delta.rowid.RowTrackingMergeCDFDVSuite", 3.18),
("org.apache.spark.sql.delta.rowid.UpdateWithRowTrackingTableFeatureCDCSuite", 3.12),
("org.apache.spark.sql.delta.UpdateSQLWithDeletionVectorsAndPredicatePushdownSuite", 3.01),
("org.apache.spark.sql.delta.rowid.RowTrackingMergeDVSuite", 2.97)
)

def apply(
Expand All @@ -256,11 +280,7 @@ object TestParallelization {
groupIdx -> group
}

println("CREATING A BRAND NEW SIMPLE HASH STRATEGY")

val initialGroupDurations = Array.fill(groupCount)(0.0)

new SimpleHashStrategy(testGroups.toMap, shardIdOpt, initialGroupDurations)
new GreedyHashStrategy(testGroups.toMap, shardIdOpt, Array.fill(groupCount)(0.0))
}
}
}

0 comments on commit e44a96f

Please sign in to comment.