From a870cbc87537283b373aa9895913cf38070f27cf Mon Sep 17 00:00:00 2001 From: Martin Traverso Date: Mon, 12 Feb 2024 12:07:25 -0800 Subject: [PATCH] Limit the max number of drivers per task in thread-per-driver scheduler --- .../dedicated/ConcurrencyController.java | 7 ++++++- .../executor/dedicated/TaskEntry.java | 4 ++-- .../ThreadPerDriverTaskExecutor.java | 19 ++++++++++++++++--- .../TestSqlTaskManagerThreadPerDriver.java | 3 ++- .../TestThreadPerDriverTaskExecutor.java | 2 +- .../TestIcebergOrcMetricsCollection.java | 2 ++ 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ConcurrencyController.java b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ConcurrencyController.java index f27fcc56909f..a1e164ea03e9 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ConcurrencyController.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ConcurrencyController.java @@ -19,12 +19,15 @@ class ConcurrencyController { private static final double TARGET_UTILIZATION = 0.5; + private final int maxConcurrency; private int targetConcurrency; - public ConcurrencyController(int initialConcurrency) + public ConcurrencyController(int initialConcurrency, int maxConcurrency) { checkArgument(initialConcurrency > 0, "initial concurrency must be positive"); + checkArgument(initialConcurrency <= maxConcurrency, "initial concurrency must be <= maxConcurrency>"); this.targetConcurrency = initialConcurrency; + this.maxConcurrency = maxConcurrency; } public void update(double utilization, int currentConcurrency) @@ -35,6 +38,8 @@ public void update(double utilization, int currentConcurrency) else if (utilization < TARGET_UTILIZATION && currentConcurrency >= targetConcurrency) { targetConcurrency++; } + + targetConcurrency = Math.min(maxConcurrency, targetConcurrency); } public int targetConcurrency() diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/TaskEntry.java b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/TaskEntry.java index b35c80d01f6a..a65b0b4b5707 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/TaskEntry.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/TaskEntry.java @@ -62,7 +62,7 @@ class TaskEntry @GuardedBy("this") private final Set running = new HashSet<>(); - public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder versionEmbedder, Tracer tracer, int initialConcurrency, DoubleSupplier utilization) + public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder versionEmbedder, Tracer tracer, int initialConcurrency, int maxConcurrency, DoubleSupplier utilization) { this.taskId = requireNonNull(taskId, "taskId is null"); this.scheduler = requireNonNull(scheduler, "scheduler is null"); @@ -71,7 +71,7 @@ public TaskEntry(TaskId taskId, FairScheduler scheduler, VersionEmbedder version this.utilization = requireNonNull(utilization, "utilization is null"); this.group = scheduler.createGroup(taskId.toString()); - this.concurrency = new ConcurrencyController(initialConcurrency); + this.concurrency = new ConcurrencyController(initialConcurrency, maxConcurrency); } public TaskId taskId() diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ThreadPerDriverTaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ThreadPerDriverTaskExecutor.java index fdbf73f29802..f089b56b67b4 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ThreadPerDriverTaskExecutor.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/dedicated/ThreadPerDriverTaskExecutor.java @@ -54,6 +54,7 @@ public class ThreadPerDriverTaskExecutor private final FairScheduler scheduler; private final Tracer tracer; private final VersionEmbedder versionEmbedder; + private final int maxDriversPerTask; private final ScheduledThreadPoolExecutor backgroundTasks = new ScheduledThreadPoolExecutor(2); @GuardedBy("this") @@ -65,15 +66,20 @@ public class ThreadPerDriverTaskExecutor @Inject public ThreadPerDriverTaskExecutor(TaskManagerConfig config, Tracer tracer, VersionEmbedder versionEmbedder) { - this(tracer, versionEmbedder, new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker())); + this( + tracer, + versionEmbedder, + new FairScheduler(config.getMaxWorkerThreads(), "SplitRunner-%d", Ticker.systemTicker()), + config.getMaxDriversPerTask()); } @VisibleForTesting - public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler) + public ThreadPerDriverTaskExecutor(Tracer tracer, VersionEmbedder versionEmbedder, FairScheduler scheduler, int maxDriversPerTask) { this.scheduler = scheduler; this.tracer = requireNonNull(tracer, "tracer is null"); this.versionEmbedder = requireNonNull(versionEmbedder, "versionEmbedder is null"); + this.maxDriversPerTask = maxDriversPerTask; } @PostConstruct @@ -104,7 +110,14 @@ public synchronized TaskHandle addTask( OptionalInt maxDriversPerTask) { checkArgument(!closed, "Executor is already closed"); - TaskEntry task = new TaskEntry(taskId, scheduler, versionEmbedder, tracer, initialSplitConcurrency, utilizationSupplier); + TaskEntry task = new TaskEntry( + taskId, + scheduler, + versionEmbedder, + tracer, + initialSplitConcurrency, + maxDriversPerTask.orElse(this.maxDriversPerTask), + utilizationSupplier); tasks.put(taskId, task); return task; } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java index 08327f8c0082..5204010271aa 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManagerThreadPerDriver.java @@ -30,6 +30,7 @@ protected TaskExecutor createTaskExecutor() return new ThreadPerDriverTaskExecutor( Tracing.noopTracer(), testingVersionEmbedder(), - new FairScheduler(8, "Runner-%d", Ticker.systemTicker())); + new FairScheduler(8, "Runner-%d", Ticker.systemTicker()), + Integer.MAX_VALUE); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java b/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java index ff6318e46601..c206a795e64e 100644 --- a/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java +++ b/core/trino-main/src/test/java/io/trino/execution/executor/dedicated/TestThreadPerDriverTaskExecutor.java @@ -121,7 +121,7 @@ public void testYielding() { TestingTicker ticker = new TestingTicker(); FairScheduler scheduler = new FairScheduler(1, "Runner-%d", ticker); - ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler); + ThreadPerDriverTaskExecutor executor = new ThreadPerDriverTaskExecutor(noopTracer(), testingVersionEmbedder(), scheduler, Integer.MAX_VALUE); executor.start(); try { diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcMetricsCollection.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcMetricsCollection.java index 1eb5a995a527..6522c0889ea4 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcMetricsCollection.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergOrcMetricsCollection.java @@ -43,6 +43,7 @@ import java.util.Map; import java.util.Optional; +import static io.trino.SystemSessionProperties.INITIAL_SPLITS_PER_NODE; import static io.trino.SystemSessionProperties.MAX_DRIVERS_PER_TASK; import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; import static io.trino.SystemSessionProperties.TASK_MAX_WRITER_COUNT; @@ -70,6 +71,7 @@ protected QueryRunner createQueryRunner() .setSystemProperty(TASK_CONCURRENCY, "1") .setSystemProperty(TASK_MIN_WRITER_COUNT, "1") .setSystemProperty(TASK_MAX_WRITER_COUNT, "1") + .setSystemProperty(INITIAL_SPLITS_PER_NODE, "1") .setSystemProperty(MAX_DRIVERS_PER_TASK, "1") .setCatalogSessionProperty("iceberg", "orc_string_statistics_limit", Integer.MAX_VALUE + "B") .build();