From d5015fb2204d3e3328dc239f7238d136e2395c1e Mon Sep 17 00:00:00 2001 From: Honnix Date: Mon, 18 Dec 2023 14:43:25 +0100 Subject: [PATCH] Make plugin task first-class citizen (#272) * Revert "Revert "Make plugin task first-class citizen (#268)" (#271)" This reverts commit 3d9ab4e48f58b1f8199c4a36daf15310b282708f. Signed-off-by: Hongxin Liang * Drop isSyncPlugin for now Signed-off-by: Hongxin Liang * Skip staging and merging custom where applicable Signed-off-by: Hongxin Liang * Skip container Signed-off-by: Hongxin Liang * More description of SdkPluginTask Signed-off-by: Hongxin Liang --------- Signed-off-by: Hongxin Liang --- .../java/org/flyte/api/v1/PluginTask.java | 20 +++ .../org/flyte/api/v1/PluginTaskRegistrar.java | 20 +++ .../java/org/flyte/api/v1/TaskTemplate.java | 3 + .../org/flyte/examples/NoopPluginTask.java | 34 ++++ .../org/flyte/flytekit/SdkPluginTask.java | 112 ++++++++++++ .../flytekit/SdkPluginTaskRegistrar.java | 141 +++++++++++++++ .../flytekit/SdkPluginTaskRegistrarTest.java | 161 ++++++++++++++++++ .../flyte/jflyte/utils/ProjectClosure.java | 60 +++++-- .../org/flyte/jflyte/utils/ProtoUtil.java | 28 +-- .../jflyte/utils/ProjectClosureTest.java | 49 ++++++ .../flyte/jflyte/ExecuteDynamicWorkflow.java | 12 +- 11 files changed, 608 insertions(+), 32 deletions(-) create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java create mode 100644 flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java create mode 100644 flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java new file mode 100644 index 000000000..af22b3ff4 --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +public interface PluginTask extends Task {} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java new file mode 100644 index 000000000..d2562e4fa --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A registrar that creates {@link PluginTask} instances. */ +public abstract class PluginTaskRegistrar implements Registrar {} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java index 5a48c8900..43a7d60cc 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java @@ -22,6 +22,9 @@ /** * A Task structure that uniquely identifies a task in the system. Tasks are registered as a first * step in the system. + * + *

FIXME: consider offering TaskMetadata instead of having everything in TaskTemplate, see + * https://github.com/flyteorg/flyte/blob/ea72bbd12578d64087221592554fb71c368f8057/flyteidl/protos/flyteidl/core/tasks.proto#L90 */ @AutoValue public abstract class TaskTemplate { diff --git a/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java b/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java new file mode 100644 index 000000000..2be34420b --- /dev/null +++ b/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples; + +import com.google.auto.service.AutoService; +import org.flyte.flytekit.SdkPluginTask; +import org.flyte.flytekit.SdkTypes; + +@AutoService(SdkPluginTask.class) +public class NoopPluginTask extends SdkPluginTask { + + public NoopPluginTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return "noop"; + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java new file mode 100644 index 000000000..88db8ff44 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java @@ -0,0 +1,112 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.flyte.api.v1.PartialTaskIdentifier; + +/** + * A task that is handled by a Flyte backend plugin instead of run as a container. Note that a + * plugin task template does not have a container defined, neither all the jars captured in + * classpath, so if this is a requirement, one should use SdkRunnableTask overriding run method to + * simply return null. + */ +public abstract class SdkPluginTask extends SdkTransform { + + private final SdkType inputType; + private final SdkType outputType; + + /** + * Called by subclasses passing the {@link SdkType}s for inputs and outputs. + * + * @param inputType type for inputs. + * @param outputType type for outputs. + */ + public SdkPluginTask(SdkType inputType, SdkType outputType) { + this.inputType = inputType; + this.outputType = outputType; + } + + public abstract String getType(); + + @Override + public SdkType getInputType() { + return inputType; + } + + @Override + public SdkType getOutputType() { + return outputType; + } + + /** Specifies custom data that can be read by the backend plugin. */ + public SdkStruct getCustom() { + return SdkStruct.empty(); + } + + /** + * Number of retries. Retries will be consumed when the task fails with a recoverable error. The + * number of retries must be less than or equals to 10. + * + * @return number of retries + */ + public int getRetries() { + return 0; + } + + /** + * Indicates whether the system should attempt to look up this task's output to avoid duplication + * of work. + */ + public boolean isCached() { + return false; + } + + /** Indicates a logical version to apply to this task for the purpose of cache. */ + public String getCacheVersion() { + return null; + } + + /** + * Indicates whether the system should attempt to execute cached instances in serial to avoid + * duplicate work. + */ + public boolean isCacheSerializable() { + return false; + } + + @Override + SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build(); + List errors = + Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap()); + + if (!errors.isEmpty()) { + throw new CompilerException(errors); + } + + return new SdkTaskNode<>( + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java new file mode 100644 index 000000000..939ef7a0e --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java @@ -0,0 +1,141 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import com.google.auto.service.AutoService; +import java.util.HashMap; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; + +/** + * Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s + * implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use + * {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the + * corresponding file. + * + * @see ServiceLoader + */ +@AutoService(PluginTaskRegistrar.class) +public class SdkPluginTaskRegistrar extends PluginTaskRegistrar { + private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName()); + + static { + // enable all levels for the actual handler to pick up + LOG.setLevel(Level.ALL); + } + + private static class PluginTaskImpl implements PluginTask { + private final SdkPluginTask sdkTask; + + private PluginTaskImpl(SdkPluginTask sdkTask) { + this.sdkTask = sdkTask; + } + + @Override + public String getType() { + return sdkTask.getType(); + } + + @Override + public Struct getCustom() { + return sdkTask.getCustom().struct(); + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(sdkTask.getInputType().getVariableMap()) + .outputs(sdkTask.getOutputType().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(sdkTask.getRetries()).build(); + } + + @Override + public boolean isCached() { + return sdkTask.isCached(); + } + + @Override + public String getCacheVersion() { + return sdkTask.getCacheVersion(); + } + + @Override + public boolean isCacheSerializable() { + return sdkTask.isCacheSerializable(); + } + + @Override + public String getName() { + return sdkTask.getName(); + } + } + + /** + * Load {@link SdkPluginTask}s using {@link ServiceLoader}. + * + * @param env env vars in a map that would be used to pick up the project, domain and version for + * the discovered tasks. + * @param classLoader class loader to use when discovering the task using {@link + * ServiceLoader#load(Class, ClassLoader)} + * @return a map of {@link SdkPluginTask}s by its task identifier. + */ + @Override + @SuppressWarnings("rawtypes") + public Map load(Map env, ClassLoader classLoader) { + ServiceLoader loader = ServiceLoader.load(SdkPluginTask.class, classLoader); + + LOG.fine("Discovering SdkPluginTask"); + + Map tasks = new HashMap<>(); + SdkConfig sdkConfig = SdkConfig.load(env); + + for (SdkPluginTask sdkTask : loader) { + String name = sdkTask.getName(); + TaskIdentifier taskId = + TaskIdentifier.builder() + .domain(sdkConfig.domain()) + .project(sdkConfig.project()) + .name(name) + .version(sdkConfig.version()) + .build(); + LOG.fine(String.format("Discovered [%s]", name)); + + PluginTask task = new PluginTaskImpl<>(sdkTask); + PluginTask previous = tasks.put(taskId, task); + + if (previous != null) { + throw new IllegalArgumentException( + String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); + } + } + + return tasks; + } +} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java new file mode 100644 index 000000000..66a481650 --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java @@ -0,0 +1,161 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static org.flyte.flytekit.SdkConfig.DOMAIN_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.PROJECT_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.VERSION_ENV_VAR; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasKey; +import static org.junit.jupiter.api.Assertions.assertAll; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class SdkPluginTaskRegistrarTest { + private static final String TASK_TYPE = "test-task"; + private static final Map ENV = + Map.of(PROJECT_ENV_VAR, "project", DOMAIN_ENV_VAR, "domain", VERSION_ENV_VAR, "version"); + + private SdkPluginTaskRegistrar registrar; + + @BeforeEach + void setUp() { + registrar = new SdkPluginTaskRegistrar(); + } + + @Test + void shouldLoadPluginTasksFromDiscoveredRegistries() { + // given + String testTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$TestTask"; + String otherTestTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$OtherTestTask"; + TaskIdentifier expectedTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(testTaskName) + .version("version") + .build(); + + TypedInterface typedInterface = + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + + RetryStrategy retries = RetryStrategy.builder().retries(0).build(); + RetryStrategy otherRetries = RetryStrategy.builder().retries(1).build(); + + PluginTask expectedTask = createPluginTask(testTaskName, typedInterface, retries); + + TaskIdentifier expectedOtherTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(otherTestTaskName) + .version("version") + .build(); + PluginTask expectedOtherTask = + createPluginTask(otherTestTaskName, typedInterface, otherRetries); + + // when + Map tasks = registrar.load(ENV); + + // then + assertAll( + () -> assertThat(tasks, hasKey(is(expectedTestTaskId))), + () -> assertThat(tasks, hasKey(is(expectedOtherTestTaskId)))); + assertTaskEquals(tasks.get(expectedTestTaskId), expectedTask); + assertTaskEquals(tasks.get(expectedOtherTestTaskId), expectedOtherTask); + } + + private PluginTask createPluginTask( + String taskName, TypedInterface typedInterface, RetryStrategy retries) { + return new PluginTask() { + @Override + public String getName() { + return taskName; + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public TypedInterface getInterface() { + return typedInterface; + } + + @Override + public RetryStrategy getRetries() { + return retries; + } + }; + } + + private void assertTaskEquals(PluginTask actualTask, PluginTask expectedTask) { + assertThat(actualTask.getName(), equalTo(expectedTask.getName())); + assertThat(actualTask.getType(), equalTo(expectedTask.getType())); + assertThat(actualTask.getCustom(), equalTo(expectedTask.getCustom())); + assertThat(actualTask.getInterface(), equalTo(expectedTask.getInterface())); + assertThat(actualTask.getRetries(), equalTo(expectedTask.getRetries())); + } + + @AutoService(SdkPluginTask.class) + public static class TestTask extends SdkPluginTask { + + private static final long serialVersionUID = 2751205856616541247L; + + public TestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + } + + @AutoService(SdkPluginTask.class) + public static class OtherTestTask extends SdkPluginTask { + + private static final long serialVersionUID = -7757282344498000982L; + + public OtherTestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public int getRetries() { + return 1; + } + } +} diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index fba0e4154..21ad821c7 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -64,6 +64,8 @@ import org.flyte.api.v1.Node; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.Resources; import org.flyte.api.v1.Resources.ResourceName; import org.flyte.api.v1.RunnableTask; @@ -135,6 +137,10 @@ public void serialize(BiConsumer output) { } private static TaskSpec applyCustom(TaskSpec taskSpec, JFlyteCustom custom) { + if (taskSpec.taskTemplate().container() == null) { + return taskSpec; + } + Struct rewrittenCustom = merge(custom.serializeToStruct(), taskSpec.taskTemplate().custom()); TaskTemplate rewrittenTaskTemplate = taskSpec.taskTemplate().toBuilder().custom(rewrittenCustom).build(); @@ -161,25 +167,26 @@ public static ProjectClosure loadAndStage( ProjectClosure closure = ProjectClosure.load(config, rewrite, packageClassLoader); - List artifacts; if (isStagingRequired(closure)) { - artifacts = stagePackageFiles(stagerSupplier.get(), packageDir); - } else { - artifacts = emptyList(); - LOG.info( - "Skipping artifact staging because there are no runnable tasks or dynamic workflow tasks"); + List artifacts = stagePackageFiles(stagerSupplier.get(), packageDir); + JFlyteCustom custom = JFlyteCustom.builder().artifacts(artifacts).build(); + return closure.applyCustom(custom); } - JFlyteCustom custom = JFlyteCustom.builder().artifacts(artifacts).build(); + LOG.info( + "Skipping artifact staging because there are no runnable tasks or dynamic workflow tasks"); - return closure.applyCustom(custom); + return closure; } private static boolean isStagingRequired(ProjectClosure closure) { return closure.taskSpecs().values().stream() .map(TaskSpec::taskTemplate) - .map(TaskTemplate::type) - .anyMatch(type -> !type.equals("raw-container")); + .anyMatch(ProjectClosure::isRunnableOrDynamicWorkflowTask); + } + + private static boolean isRunnableOrDynamicWorkflowTask(TaskTemplate taskTemplate) { + return taskTemplate.container() != null && !taskTemplate.type().equals("raw-container"); } private static List stagePackageFiles(ArtifactStager stager, String packageDir) { @@ -219,6 +226,10 @@ static ProjectClosure load( ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + Map workflows = ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(WorkflowTemplateRegistrar.class, env)); @@ -233,6 +244,7 @@ static ProjectClosure load( runnableTasks, dynamicWorkflowTasks, containerTasks, + pluginTasks, workflows, launchPlans); } @@ -243,10 +255,12 @@ static ProjectClosure load( Map runnableTasks, Map dynamicWorkflowTasks, Map containerTasks, + Map pluginTasks, Map workflowTemplates, Map launchPlans) { Map taskTemplates = - createTaskTemplates(config, runnableTasks, dynamicWorkflowTasks, containerTasks); + createTaskTemplates( + config, runnableTasks, dynamicWorkflowTasks, containerTasks, pluginTasks); // 2. rewrite workflows and launch plans Map rewrittenWorkflowTemplates = @@ -424,7 +438,8 @@ public static Map createTaskTemplates( ExecutionConfig config, Map runnableTasks, Map dynamicWorkflowTasks, - Map containerTasks) { + Map containerTasks, + Map pluginTasks) { Map taskTemplates = new HashMap<>(); runnableTasks.forEach( @@ -448,6 +463,13 @@ public static Map createTaskTemplates( taskTemplates.put(id, taskTemplate); }); + pluginTasks.forEach( + (id, task) -> { + TaskTemplate taskTemplate = createTaskTemplateForPluginTask(task); + + taskTemplates.put(id, taskTemplate); + }); + return taskTemplates; } @@ -473,7 +495,7 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String .resources(task.getResources()) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); } @VisibleForTesting @@ -488,13 +510,17 @@ static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { .resources(resources) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); + } + + @VisibleForTesting + static TaskTemplate createTaskTemplateForPluginTask(PluginTask task) { + return createTaskTemplateBuilder(task).build(); } - private static TaskTemplate createTaskTemplate(Task task, Container container) { + private static TaskTemplate.Builder createTaskTemplateBuilder(Task task) { TaskTemplate.Builder templateBuilder = TaskTemplate.builder() - .container(container) .interface_(task.getInterface()) .retries(task.getRetries()) .type(task.getType()) @@ -506,7 +532,7 @@ private static TaskTemplate createTaskTemplate(Task task, Container container) { templateBuilder.discoveryVersion(task.getCacheVersion()); } - return templateBuilder.build(); + return templateBuilder; } private static Optional javaToolOptionsEnv(RunnableTask task) { diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index 3821db033..1d542b0b9 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -19,7 +19,6 @@ import static com.google.common.base.Strings.emptyToNull; import static com.google.common.base.Strings.nullToEmpty; import static java.time.format.DateTimeFormatter.ISO_DATE_TIME; -import static java.util.Objects.requireNonNull; import static org.flyte.jflyte.utils.MoreCollectors.mapValues; import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableList; import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableMap; @@ -45,6 +44,8 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; +import flyteidl.core.Tasks.RuntimeMetadata; +import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -310,25 +311,24 @@ static TaskOuterClass.TaskSpec serialize(TaskSpec spec) { } static Tasks.TaskTemplate serialize(TaskTemplate taskTemplate) { - Container container = - requireNonNull( - taskTemplate.container(), "Only container based task templates are supported"); - TaskMetadata metadata = serializeTaskMetadata(taskTemplate); - return Tasks.TaskTemplate.newBuilder() - .setContainer(serialize(container)) - .setMetadata(metadata) - .setInterface(serialize(taskTemplate.interface_())) - .setType(taskTemplate.type()) - .setCustom(serializeStruct(taskTemplate.custom())) - .build(); + Tasks.TaskTemplate.Builder builder = + Tasks.TaskTemplate.newBuilder() + .setMetadata(metadata) + .setInterface(serialize(taskTemplate.interface_())) + .setType(taskTemplate.type()) + .setCustom(serializeStruct(taskTemplate.custom())); + if (taskTemplate.container() == null) { + return builder.build(); + } + return builder.setContainer(serialize(taskTemplate.container())).build(); } private static TaskMetadata serializeTaskMetadata(TaskTemplate taskTemplate) { Tasks.RuntimeMetadata runtime = - Tasks.RuntimeMetadata.newBuilder() - .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) + RuntimeMetadata.newBuilder() + .setType(RuntimeType.FLYTE_SDK) .setFlavor(RUNTIME_FLAVOR) .setVersion(RUNTIME_VERSION) .build(); diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index cd1f698e7..d7b270da1 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -53,6 +53,7 @@ import org.flyte.api.v1.Operand; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Resources; import org.flyte.api.v1.RetryStrategy; @@ -582,6 +583,27 @@ public void testCreateTaskTemplateForContainerTask() { assertThat(result.type(), equalTo("raw-container")); } + @Test + public void testCreateTaskTemplateForPluginTask() { + // given + PluginTask task = createPluginTask(); + + // when + TaskTemplate result = ProjectClosure.createTaskTemplateForPluginTask(task); + + // then + assertThat( + result.interface_(), + equalTo( + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build())); + assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); + assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); + assertThat(result.type(), equalTo("test-plugin-task")); + } + @Test public void testCreateTaskTemplateForTasksWithDefaultCacheSettings() { // given @@ -829,6 +851,33 @@ public List getEnv() { }; } + private PluginTask createPluginTask() { + return new PluginTask() { + @Override + public String getName() { + return "foo"; + } + + @Override + public String getType() { + return "test-plugin-task"; + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(0).build(); + } + }; + } + private T wrapTaskWithRetries(Class taskClass, T task, boolean cacheEnabled) { return Reflection.newProxy( taskClass, diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index c011611dc..b46c44895 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -44,6 +44,8 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.NamedEntityIdentifier; import org.flyte.api.v1.Node; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.RunnableTaskRegistrar; import org.flyte.api.v1.Struct; @@ -150,6 +152,10 @@ private void execute() { ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + // before we run anything, switch class loader, otherwise, // ServiceLoaders and other things wouldn't work, for instance, // FileSystemRegister in Apache Beam @@ -162,7 +168,11 @@ private void execute() { Map taskTemplates = mapValues( ProjectClosure.createTaskTemplates( - executionConfig, runnableTasks, dynamicWorkflowTasks, containerTasks), + executionConfig, + runnableTasks, + dynamicWorkflowTasks, + containerTasks, + pluginTasks), template -> template.toBuilder() .custom(ProjectClosure.merge(template.custom(), custom))