diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java b/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java index 6a8118c6..4bc67d20 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/LaunchPlan.java @@ -19,6 +19,7 @@ import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.Map; +import java.util.Optional; import javax.annotation.Nullable; /** User-provided launch plan definition and configuration values. */ @@ -40,6 +41,11 @@ public abstract class LaunchPlan { */ public abstract Map defaultInputs(); + /** + * Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. + */ + public abstract Optional maxParallelism(); + @Nullable public abstract CronSchedule cronSchedule(); @@ -64,6 +70,8 @@ public abstract static class Builder { public abstract Builder cronSchedule(CronSchedule cronSchedule); + public abstract Builder maxParallelism(Optional maxParallelism); + public abstract LaunchPlan build(); } } diff --git a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java index a4471eb0..a6911d9a 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/FibonacciLaunchPlan.java @@ -18,6 +18,7 @@ import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; +import java.util.Optional; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLaunchPlan; @@ -53,10 +54,20 @@ public FibonacciLaunchPlan() { .withName("FibonacciWorkflowLaunchPlan3") .withDefaultInput("fib0", 0L) .withDefaultInput("fib1", 1L)); + + // Register launch plan with fixed inputs and maxParallelism of 10 + registerLaunchPlan( + SdkLaunchPlan.of(new FibonacciWorkflow()) + .withName("FibonacciWorkflowLaunchPlan4") + .withFixedInputs( + JacksonSdkType.of(Input.class), + Input.create(SdkBindingDataFactory.of(0), SdkBindingDataFactory.of(1))) + .withMaxParallelism(Optional.of(10))); } @AutoValue abstract static class Input { + abstract SdkBindingData fib0(); abstract SdkBindingData fib1(); diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java index 7e234ffa..3424a286 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlan.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import javax.annotation.Nullable; import org.flyte.api.v1.Literal; @@ -81,6 +82,10 @@ public abstract class SdkLaunchPlan { @Nullable public abstract SdkCronSchedule cronSchedule(); + /** Returns the max parallelism of the launch plan. */ + @Nullable + public abstract Optional maxParallelism(); + /** * Creates a launch plan for specified {@link SdkLaunchPlan} with default naming, no inputs and no * schedule. The default launch plan name is {@link SdkWorkflow#getName()}. New name, inputs and @@ -322,6 +327,16 @@ public SdkLaunchPlan withDefaultInput(SdkType type, T value) { v -> createParameter(v.getValue(), literalMap.get(v.getKey()))))); } + /** + * @param maxParallelism Optional Integer for the max parallelism (cannot be negative). Default + * Value: Empty, it will default to what's set in the Flyte Platform. 0: It will try to use as + * much as allowed. + * @return the new launch plan + */ + public SdkLaunchPlan withMaxParallelism(Optional maxParallelism) { + return withMaxParallelism0(maxParallelism); + } + private SdkLaunchPlan withDefaultInputs0(Map newDefaultInputs) { verifyNonEmptyWorkflowInput(newDefaultInputs, "default"); @@ -336,6 +351,17 @@ private SdkLaunchPlan withDefaultInputs0(Map newDefaultInputs return toBuilder().defaultInputs(newCompleteDefaultInputs).build(); } + private SdkLaunchPlan withMaxParallelism0(Optional maxParallelism) { + if (maxParallelism.isPresent() && maxParallelism.get() < 0) { + String message = + String.format( + "invalid max parallelism %s, expected a positive integer", maxParallelism.get()); + throw new IllegalArgumentException(message); + } + + return toBuilder().maxParallelism(maxParallelism).build(); + } + private Map mergeInputs( Map oldInputs, Map newInputs, String inputType) { Map newCompleteInputs = new LinkedHashMap<>(oldInputs); @@ -388,7 +414,8 @@ static Builder builder() { return new AutoValue_SdkLaunchPlan.Builder() .fixedInputs(Collections.emptyMap()) .defaultInputs(Collections.emptyMap()) - .workflowInputTypeMap(Collections.emptyMap()); + .workflowInputTypeMap(Collections.emptyMap()) + .maxParallelism(Optional.empty()); } abstract Builder toBuilder(); @@ -414,6 +441,8 @@ abstract static class Builder { abstract Builder workflowInputTypeMap(Map workflowInputTypeMap); + abstract Builder maxParallelism(Optional maxParallelism); + abstract SdkLaunchPlan build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java index 9913b18d..ac0f121c 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLaunchPlanRegistrar.java @@ -96,7 +96,8 @@ Map load( .name(sdkLaunchPlan.name()) .workflowId(getWorkflowIdentifier(sdkLaunchPlan)) .fixedInputs(sdkLaunchPlan.fixedInputs()) - .defaultInputs(sdkLaunchPlan.defaultInputs()); + .defaultInputs(sdkLaunchPlan.defaultInputs()) + .maxParallelism(sdkLaunchPlan.maxParallelism()); if (sdkLaunchPlan.cronSchedule() != null) { builder.cronSchedule(getCronSchedule(sdkLaunchPlan.cronSchedule())); diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java index 769b5d53..ef1498db 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanRegistrarTest.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import org.flyte.api.v1.CronSchedule; import org.flyte.api.v1.LaunchPlan; import org.flyte.api.v1.LaunchPlanIdentifier; @@ -137,6 +138,34 @@ void shouldTestLaunchPlansWithCronSchedule() { hasEntry(expectedIdentifierWithOffset, planWithOffset))); } + @Test + void shouldTestLaunchPlansWithMaxParallelism() { + Map launchPlans = + registrar.load(ENV, singletonList(new TestRegistryWithMaxParallelism())); + + LaunchPlanIdentifier expectedIdentifierWithOffset = + LaunchPlanIdentifier.builder() + .project("project") + .domain("domain") + .name("TestPlanScheduleWithMaxParallelism") + .version("version") + .build(); + + LaunchPlan planWithOffset = + LaunchPlan.builder() + .name("TestPlanScheduleWithMaxParallelism") + .workflowId( + PartialWorkflowIdentifier.builder() + .name("org.flyte.flytekit.SdkLaunchPlanRegistrarTest$TestWorkflow") + .build()) + .fixedInputs(Collections.emptyMap()) + .defaultInputs(Collections.emptyMap()) + .maxParallelism(Optional.of(10)) + .build(); + + assertThat(launchPlans, allOf(hasEntry(expectedIdentifierWithOffset, planWithOffset))); + } + @Test void shouldRejectLoadingLaunchPlanDuplicatesInSameRegistry() { IllegalArgumentException exception = @@ -208,6 +237,17 @@ public List getLaunchPlans() { } } + public static class TestRegistryWithMaxParallelism implements SdkLaunchPlanRegistry { + + @Override + public List getLaunchPlans() { + return Arrays.asList( + SdkLaunchPlan.of(new TestWorkflow()) + .withName("TestPlanScheduleWithMaxParallelism") + .withMaxParallelism(Optional.of(10))); + } + } + public static class TestWorkflow extends SdkWorkflow { public TestWorkflow() { diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java index bace2a7f..815c36bc 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkLaunchPlanTest.java @@ -31,6 +31,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import java.util.stream.Stream; import org.flyte.api.v1.Literal; @@ -91,6 +92,14 @@ void shouldCreateLaunchPlanWithCronSchedule() { assertThat(plan.cronSchedule().offset(), equalTo(Duration.ofHours(1))); } + @Test + void shouldCreateLaunchPlanWithMaxParallelism() { + SdkLaunchPlan plan = SdkLaunchPlan.of(new TestWorkflow()).withMaxParallelism(Optional.of(123)); + + assertThat(plan.maxParallelism(), notNullValue()); + assertThat(plan.maxParallelism().get(), equalTo(123)); + } + @Test void shouldAddFixedInputs() { Instant now = Instant.now(); diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/IdentifierRewrite.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/IdentifierRewrite.java index 43439028..cfc35d80 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/IdentifierRewrite.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/IdentifierRewrite.java @@ -112,6 +112,7 @@ LaunchPlan apply(LaunchPlan launchPlan) { .defaultInputs(launchPlan.defaultInputs()) .workflowId(apply(launchPlan.workflowId())) .cronSchedule(launchPlan.cronSchedule()) + .maxParallelism(launchPlan.maxParallelism()) .build(); } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index bddca2a3..957443c1 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -113,6 +113,7 @@ /** Utility to serialize between flytekit-api and flyteidl proto. */ @SuppressWarnings("PreferJavaTimeOverload") public class ProtoUtil { + public static final String RUNTIME_FLAVOR = "java"; public static final String RUNTIME_VERSION = "0.0.1"; @@ -717,6 +718,8 @@ static LaunchPlanOuterClass.LaunchPlanSpec serialize(LaunchPlan launchPlan) { .setFixedInputs(ProtoUtil.serialize(launchPlan.fixedInputs())) .setDefaultInputs(ProtoUtil.serializeParameters(launchPlan.defaultInputs())); + launchPlan.maxParallelism().ifPresent(specBuilder::setMaxParallelism); + if (launchPlan.cronSchedule() != null) { ScheduleOuterClass.Schedule schedule = ProtoUtil.serialize(launchPlan.cronSchedule()); specBuilder.setEntityMetadata( diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java index 93344b32..ca728970 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java @@ -54,6 +54,7 @@ import java.time.Duration; import java.util.Arrays; import java.util.Collections; +import java.util.Optional; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.CronSchedule; @@ -219,6 +220,7 @@ public void shouldPropagateLaunchPlanToStub() { LaunchPlan launchPlan = LaunchPlan.builder() .workflowId(wfIdentifier) + .maxParallelism(Optional.of(20)) .name(LP_NAME) .fixedInputs( Collections.singletonMap( @@ -249,6 +251,7 @@ public void shouldPropagateLaunchPlanToStub() { .setSpec( LaunchPlanOuterClass.LaunchPlanSpec.newBuilder() .setWorkflowId(newIdentifier(ResourceType.WORKFLOW, WF_NAME, WF_VERSION)) + .setMaxParallelism(20) .setFixedInputs( Literals.LiteralMap.newBuilder() .putLiterals( diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java index 0bc3b76c..040a22e9 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java @@ -25,6 +25,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; import static java.util.Collections.singletonMap; +import static org.flyte.jflyte.utils.ProtoUtil.serialize; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -38,6 +39,7 @@ import com.google.protobuf.ListValue; import com.google.protobuf.NullValue; import com.google.protobuf.Value; +import flyteidl.admin.LaunchPlanOuterClass.LaunchPlanSpec; import flyteidl.admin.ScheduleOuterClass; import flyteidl.core.Condition; import flyteidl.core.DynamicJob; @@ -56,6 +58,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; import org.flyte.api.v1.Binary; import org.flyte.api.v1.Binding; @@ -75,6 +78,7 @@ import org.flyte.api.v1.IfBlock; import org.flyte.api.v1.IfElseBlock; import org.flyte.api.v1.KeyValuePair; +import org.flyte.api.v1.LaunchPlan; import org.flyte.api.v1.LaunchPlanIdentifier; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; @@ -989,6 +993,27 @@ public void shouldSerializeCronSchedule() { .build())); } + @Test + public void shouldSerializeLaunchPlanMaxParallelism() { + Optional maxParallelism = Optional.of(10); + LaunchPlan launchPlan = + LaunchPlan.builder() + .name("name") + .workflowId( + PartialWorkflowIdentifier.builder() + .project("test-project") + .domain("test-domain") + .version("a-version") + .name("name") + .build()) + .maxParallelism(maxParallelism) + .build(); + + LaunchPlanSpec res = serialize(launchPlan); + + assertThat(res.getMaxParallelism(), equalTo(10)); + } + @Test public void shouldSerializeCronScheduleNoOffset() { CronSchedule cronSchedule = CronSchedule.builder().schedule("* * */5 * *").build(); @@ -1252,7 +1277,7 @@ void shouldSerializeContainerWithResources() { Resources.ResourceName.CPU, "8", Resources.ResourceName.MEMORY, "32G")) .build()); - Tasks.Container actual = ProtoUtil.serialize(container); + Tasks.Container actual = serialize(container); assertThat( actual, @@ -1296,7 +1321,7 @@ void shouldAcceptResourcesWithValidQuantities(String quantity) { .limits(ImmutableMap.of(Resources.ResourceName.CPU, quantity)) .build()); - Tasks.Container actual = ProtoUtil.serialize(container); + Tasks.Container actual = serialize(container); assertThat( actual, @@ -1326,7 +1351,7 @@ void shouldRejectResourcesWithInvalidQuantities(String quantity) { .build()); IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> ProtoUtil.serialize(container)); + assertThrows(IllegalArgumentException.class, () -> serialize(container)); assertEquals( "Resource requests [CPU] has invalid quantity: " + quantity, exception.getMessage());