Skip to content

Commit

Permalink
Limit the max number of drivers per task in thread-per-driver scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Feb 13, 2024
1 parent a7b90cd commit a870cbc
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ class ConcurrencyController
{
private static final double TARGET_UTILIZATION = 0.5;

private final int maxConcurrency;
private int targetConcurrency;

public ConcurrencyController(int initialConcurrency)
public ConcurrencyController(int initialConcurrency, int maxConcurrency)
{
checkArgument(initialConcurrency > 0, "initial concurrency must be positive");
checkArgument(initialConcurrency <= maxConcurrency, "initial concurrency must be <= maxConcurrency>");
this.targetConcurrency = initialConcurrency;
this.maxConcurrency = maxConcurrency;
}

public void update(double utilization, int currentConcurrency)
Expand All @@ -35,6 +38,8 @@ public void update(double utilization, int currentConcurrency)
else if (utilization < TARGET_UTILIZATION && currentConcurrency >= targetConcurrency) {
targetConcurrency++;
}

targetConcurrency = Math.min(maxConcurrency, targetConcurrency);
}

public int targetConcurrency()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TaskEntry
@GuardedBy("this")
private final Set<SplitRunner> running = new HashSet<>();

public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder versionEmbedder, Tracer tracer, int initialConcurrency, DoubleSupplier utilization)
public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder versionEmbedder, Tracer tracer, int initialConcurrency, int maxConcurrency, DoubleSupplier utilization)
{
this.taskId = requireNonNull(taskId, "taskId is null");
this.scheduler = requireNonNull(scheduler, "scheduler is null");
Expand All @@ -71,7 +71,7 @@ public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder version
this.utilization = requireNonNull(utilization, "utilization is null");

this.group = scheduler.createGroup(taskId.toString());
this.concurrency = new ConcurrencyController(initialConcurrency);
this.concurrency = new ConcurrencyController(initialConcurrency, maxConcurrency);
}

public TaskId taskId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class ThreadPerDriverTaskExecutor
private final FairScheduler scheduler;
private final Tracer tracer;
private final VersionEmbedder versionEmbedder;
private final int maxDriversPerTask;
private final ScheduledThreadPoolExecutor backgroundTasks = new ScheduledThreadPoolExecutor(2);

@GuardedBy("this")
Expand All @@ -65,15 +66,20 @@ public class ThreadPerDriverTaskExecutor
@Inject
public ThreadPerDriverTaskExecutor(TaskManagerConfig config, Tracer tracer, VersionEmbedder versionEmbedder)
{
this(tracer, versionEmbedder, new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()));
this(
tracer,
versionEmbedder,
new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()),
config.getMaxDriversPerTask());
}

@VisibleForTesting
public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler)
public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler, int maxDriversPerTask)
{
this.scheduler = scheduler;
this.tracer = requireNonNull(tracer, "tracer is null");
this.versionEmbedder = requireNonNull(versionEmbedder, "versionEmbedder is null");
this.maxDriversPerTask = maxDriversPerTask;
}

@PostConstruct
Expand Down Expand Up @@ -104,7 +110,14 @@ public synchronized TaskHandle addTask(
OptionalInt maxDriversPerTask)
{
checkArgument(!closed, "Executor is already closed");
TaskEntry task = new TaskEntry(taskId, scheduler, versionEmbedder, tracer, initialSplitConcurrency, utilizationSupplier);
TaskEntry task = new TaskEntry(
taskId,
scheduler,
versionEmbedder,
tracer,
initialSplitConcurrency,
maxDriversPerTask.orElse(this.maxDriversPerTask),
utilizationSupplier);
tasks.put(taskId, task);
return task;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ protected TaskExecutor createTaskExecutor()
return new ThreadPerDriverTaskExecutor(
Tracing.noopTracer(),
testingVersionEmbedder(),
new FairScheduler(8, "Runner-%d", Ticker.systemTicker()));
new FairScheduler(8, "Runner-%d", Ticker.systemTicker()),
Integer.MAX_VALUE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void testYielding()
{
TestingTicker ticker = new TestingTicker();
FairScheduler scheduler = new FairScheduler(1, "Runner-%d", ticker);
ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler);
ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler, Integer.MAX_VALUE);
executor.start();

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.Map;
import java.util.Optional;

import static io.trino.SystemSessionProperties.INITIAL_SPLITS_PER_NODE;
import static io.trino.SystemSessionProperties.MAX_DRIVERS_PER_TASK;
import static io.trino.SystemSessionProperties.TASK_CONCURRENCY;
import static io.trino.SystemSessionProperties.TASK_MAX_WRITER_COUNT;
Expand Down Expand Up @@ -70,6 +71,7 @@ protected QueryRunner createQueryRunner()
.setSystemProperty(TASK_CONCURRENCY, "1")
.setSystemProperty(TASK_MIN_WRITER_COUNT, "1")
.setSystemProperty(TASK_MAX_WRITER_COUNT, "1")
.setSystemProperty(INITIAL_SPLITS_PER_NODE, "1")
.setSystemProperty(MAX_DRIVERS_PER_TASK, "1")
.setCatalogSessionProperty("iceberg", "orc_string_statistics_limit", Integer.MAX_VALUE + "B")
.build();
Expand Down

0 comments on commit a870cbc

Please sign in to comment.